diff --git a/.clang-format b/.clang-format
new file mode 100644
index 0000000000000000000000000000000000000000..1595fa33ca90b4c1b6862e5fc8f5bbd844e90f5c
--- /dev/null
+++ b/.clang-format
@@ -0,0 +1,85 @@
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+---
+# Detailed option reference: https://clang.llvm.org/docs/ClangFormatStyleOptions.html
+---
+Language: Cpp
+BasedOnStyle: Google
+# Do not indent access specifiers such as public
+AccessModifierOffset: -4
+# Limit lines to 120 characters
+ColumnLimit: 120
+# 4-space indentation
+IndentWidth: 4
+# Never use tabs
+UseTab: Never
+# Align operands of wrapped binary operators
+AlignOperands: Align
+# Align wrapped parameters
+AlignAfterOpenBracket: Align
+# Align trailing comments
+AlignTrailingComments: true
+DerivePointerAlignment: false
+# Bind references and pointers to the left
+PointerAlignment: Left
+AllowAllParametersOfDeclarationOnNextLine: false
+AllowAllArgumentsOnNextLine: false
+AllowShortBlocksOnASingleLine: Empty
+AllowShortCaseLabelsOnASingleLine: false
+AllowShortEnumsOnASingleLine: false
+AllowShortFunctionsOnASingleLine: Empty
+AllowShortIfStatementsOnASingleLine: false
+AllowShortLoopsOnASingleLine: false
+AllowShortLambdasOnASingleLine: Inline
+# Break after return type automatically
+AlwaysBreakAfterDefinitionReturnType: None
+AlwaysBreakBeforeMultilineStrings: false
+# Allow arguments and parameters to be partially bin-packed
+BinPackArguments: true
+BinPackParameters: true
+BreakBeforeBraces: Custom
+BraceWrapping:
+  AfterClass: false
+  AfterControlStatement: false
+  AfterEnum: false
+  # Only function braces go on their own line
+  AfterFunction: true
+  AfterNamespace: false
+  AfterStruct: false
+  AfterUnion: false
+  AfterExternBlock: false
+  BeforeCatch: false
+  BeforeElse: false
+  IndentBraces: false
+# When a binary expression wraps, keep the operator at the end of the first line
+BreakBeforeBinaryOperators: None
+# When a ternary expression wraps, put the operator on the next line
+BreakBeforeTernaryOperators: true
+# Constructor initializer lists: colon after the break, comma before it
+BreakConstructorInitializers: BeforeColon
+BreakStringLiterals: true
+CompactNamespaces: false
+# Constructor initializers either all on one line or one per line
+PackConstructorInitializers: CurrentLine
+ConstructorInitializerIndentWidth: 4
+ContinuationIndentWidth: 4
+# Use C++11 uniform (braced) initialization style
+Cpp11BracedListStyle: true
+DisableFormat: false
+FixNamespaceComments: true
+# Do not indent a wrapped function name after a return-type break
+IndentWrappedFunctionNames: false
+Standard: Latest
diff --git a/README.md b/README.md
index 4a2c310b9054468e933358d1526209aae245bba9..5a2d9c0396fb03f0464f6d72941ba633f41d5d54 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@ mxRec is an application-enablement SDK for search, recommendation, and advertising in the internet market
 ## Installation
 
-Before installing, install the CANN development kit package and the TensorFlow adapter plugin for Ascend as described in the "CANN Software Installation Guide CANN Software Installation Guide".
+Before installing, install the CANN development kit package and the TensorFlow adapter plugin for Ascend as described in the [CANN Software Installation Guide](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha003/softwareinst/instg/instg_0022.html).
 
 CANN ships process-level environment setup scripts that user processes can source to set the required environment variables automatically; the settings lapse when the process exits. Set the CANN environment variables with the following commands in the shell script that starts your program, or run them directly on the command line (the root user's default installation path /usr/local/Ascend is used as an example):
 ```shell
@@ -24,7 +24,7 @@ source /usr/local/Ascend/tfplugin/set_env.sh
 Install the dependencies. If you did not build an image and develop directly on a physical machine, install the following Python dependencies:
 ```shell
-pip3 install numpy decorator sympy==1.4 cffi==1.12.3 pyyaml pathlib2 grpcio grpcio-tools protobuf==3.20.0 scipy requests mpi4py easydict scikit-learn==0.20.0 attrs
+pip3 install numpy decorator sympy==1.4 cffi==1.12.3 pyyaml pathlib2 pandas grpcio grpcio-tools protobuf==3.20.0 scipy requests mpi4py easydict scikit-learn==0.20.0 attrs
 ```
 
 Before installing the horovod dependency, set "HOROVOD_WITH_MPI" and "HOROVOD_WITH_TENSORFLOW"; refer to the following install command.
@@ -58,37 +58,95 @@ bash run.sh
 - CMake 3.20.6
 
 Open-source dependencies:
-- pybind11 v2.10.3
-- securec
-- openmpi 4.1.1: install it in the build environment as described in its documentation
+- [pybind11 v2.10.3](https://github.com/pybind/pybind11/archive/refs/tags/v2.10.3.zip)
+- [securec](https://github.com/huaweicloud/huaweicloud-sdk-c-obs/archive/refs/tags/v3.23.9.zip)
+- [openmpi 4.1.5](https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-4.1.5.tar.gz): install it in the build environment as described in its documentation
 - tensorflow 1.15/2.6.5: choose the version that matches your needs
 
-Put the pybind11 archive in the opensource/opensource directory next to the MxRec code. If there is no opensource directory, create opensource/opensource manually next to MxRec, then put the pybind11 archive there. Unpack the archive and rename the unpacked directory to pybind11.
+Put the pybind11 and securec archives in the opensource directory next to the mxRec code and rename them to pybind11-2.10.3.zip and huaweicloud-sdk-c-obs-3.23.9.zip respectively. If there is no opensource directory, create it manually next to mxRec first, then put the two archives there.
 
-securec is Huawei's open-source secure function library. After downloading it:
-1. Delete the eSDK_LogAPI_V2.1.10 folder under platform
-2. Rename huaweisecurec under platform to securec
-3. The securec folder contains src, lib and include folders; delete all files under lib
-4. Put the platform folder into the MxRec code directory
+The build scripts carry adaptation code for an internal build system. Users generally do not need it, so make the following adjustment before compiling:
+
+build_tf1.sh and build_tf2.sh in the build directory contain, respectively, the following code:
+```shell
+# Configure the tf1 path
+source /opt/buildtools/tf1_env/bin/activate
+tf1_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow_core
+deactivate tf1_env
+```
+```shell
+# Configure the tf2 path
+source /opt/buildtools/tf2_env/bin/activate
+tf2_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow
+deactivate tf2_env
+```
+
+Both snippets activate a Python virtual environment, so users have two options:
+
+1. Create the tf1_env and tf2_env Python virtual environments under /opt/buildtools/ (create the directory first if it does not exist) and install the matching TensorFlow version in each
+2. Comment out or delete the source /opt/buildtools/tf1_env/bin/activate and deactivate tf1_env lines, or the source /opt/buildtools/tf2_env/bin/activate and deactivate tf2_env lines, whichever applies
 
-To build whl packages for multiple versions, the build scripts install the matching tensorflow version into a python virtual environment. Adjust the scripts to your setup to point at your tensorflow installation. How to build:
-- build/build.sh: builds and packages both the tf1 and tf2 whl packages. Before running it, create the corresponding virtual environments as in build/build_tf1.sh and build/build_tf2.sh, install the matching tensorflow version in each, and adjust the activation commands.
-- build/build_tf1_with_opensource.sh: builds the tf1 whl package; on success the whl is in the tf1_whl subdirectory. Before running it, create a tf1 virtual environment, install tensorflow 1.15.0 in it, and adjust the activation commands.
-- build/build_tf2_with_opensource.sh: builds the tf2 whl package; on success the whl is in the tf2_whl subdirectory. Before running it, create a tf2 virtual environment, install tensorflow 2.6.5 in it, and adjust the activation commands.
+
+How to build:
+
+From the mxRec code directory:
+- setup.py: for internal use; it builds the tf1 and tf2 mxRec packages in one go. Users usually need only one of them, so prefer the two scripts below.
+- setup_tf1.py: run it, e.g. **python3.7 setup_tf1.py bdist_wheel**, to build the tf1 whl package; on success the whl is in the build/mindxsdk-mxrec/tf1_whl subdirectory.
+- setup_tf2.py: run it, e.g. **python3.7 setup_tf2.py bdist_wheel**, to build the tf2 whl package; on success the whl is in the build/mindxsdk-mxrec/tf2_whl subdirectory.
 
 To use the dynamic expansion feature, enter the "./cust_op/cust_op_by_addr" directory and build and install the dynamic expansion operator package with the following command.
 ```shell
 bash run.sh
 ```
+## Test cases
+
+### Python-side test cases
+
+Dependencies for running the Python test cases:
+
+- pytest 7.1.1
+- pytest-cov 4.1.0
+- pytest-html
+
+To run the Python test cases, install the dependencies above and make sure the sources build in a tf1 environment, then enter the tests directory and run:
+```shell
+bash run_python_dt.sh
+```
+
+### C++-side test cases
+
+Dependencies for running the C++ test cases:
+
+- [googletest 1.8.1](https://github.com/google/googletest/archive/refs/tags/release-1.8.1.zip)
+- [emock 0.9.0](https://github.com/ez8-co/emock/archive/refs/tags/v0.9.0.zip)
+- [pybind11 v2.10.3](https://github.com/pybind/pybind11/archive/refs/tags/v2.10.3.zip)
+- [securec](https://github.com/huaweicloud/huaweicloud-sdk-c-obs/archive/refs/tags/v3.23.9.zip)
+
+Put the googletest, emock, pybind11 and securec archives in the opensource directory next to the mxRec code and rename them to googletest-release-1.8.1.zip,
+emock-0.9.0.zip, pybind11-2.10.3.zip and huaweicloud-sdk-c-obs-3.23.9.zip. If there is no opensource directory, create it manually next to mxRec,
+then put the archives there.
+
+To run the C++ test cases, prepare the dependencies as described above, then enter the src directory and run them as follows.
+
+In a tf1 environment:
+```shell
+bash test_ut.sh tf1
+```
+
+In a tf2 environment:
+```shell
+bash test_ut.sh tf2
+```
+
 ## Usage
 
-For supported environments, features, APIs and usage samples, see the MindX SDK product documentation in the Ascend open-source community.
+For supported environments, features, APIs and usage samples, see the [mxRec User Guide](https://www.hiascend.com/document/detail/zh/mind-sdk/60rc1/mxRec/mxrecug/mxrecug_0001.html).
 
 ## Reference designs
 
-Base images for the mxrec framework, built on TensorFlow 1.15.0 and tensorflow 2.6.5; install mxrec on top and start training. Sample usage is also covered.
+Base images for the mxRec framework, built on TensorFlow 1.15.0 and tensorflow 2.6.5; install mxRec on top and start training. Sample usage is also covered.
 
-1. https://ascendhub.huawei.com/#/detail/mxrec-tf1
+1. https://www.hiascend.com/developer/ascendhub/detail/mxrec-tf1
 
-2. https://ascendhub.huawei.com/#/detail/mxrec-tf2
+2.
https://www.hiascend.com/developer/ascendhub/detail/mxrec-tf2 diff --git a/build/build_tf1_with_opensource.sh b/build/build_tf1.sh similarity index 71% rename from build/build_tf1_with_opensource.sh rename to build/build_tf1.sh index ff59571c541ae04037a80125421dc1a2a3eeb0e4..5d6632d692b5181588ef83cb58459b304f995334 100644 --- a/build/build_tf1_with_opensource.sh +++ b/build/build_tf1.sh @@ -15,13 +15,11 @@ # ============================================================================== ################################################################## -# build_tf1_with_opensource.sh 编译MxRec和动态扩容算子 +# build_tf1.sh 编译MxRec # 编译环境:Python3.7.5 GCC 7.3.0 CMake 3.20.6 -# 代码主要分为四部分: +# 代码主要分为两部分: # 1、准备编译MxRec所需依赖:pybind11(v2.10.3) securec # 2、编译securec、AccCTR以及MxRec -# 3、生成MxRec Wheel包,生成的whl包在当前目录下的mindxsdk-mxrec/tf1_whl -# 4、编译动态扩容算子 ################################################################## set -e @@ -64,33 +62,6 @@ source /opt/buildtools/tf1_env/bin/activate tf1_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow_core deactivate tf1_env -project_output_path="${MxRec_DIR}"/output/ -VERSION_FILE="${MxRec_DIR}"/../mindxsdk/build/conf/config.yaml - -function get_version() { - if [ -f "$VERSION_FILE" ]; then - VERSION=$(sed '/.*mindxsdk:/!d;s/.*: //' "$VERSION_FILE") - if [[ "$VERSION" == *.[b/B]* ]] && [[ "$VERSION" != *.[RC/rc]* ]]; then - VERSION=${VERSION%.*} - fi - else - VERSION="5.0.0" - fi -} - -rm -rf "${project_output_path}" -rm -rf "${SCRIPT_DIR}/lib" - -# 获取MxRec版本信息 -get_version -export VERSION -echo "MindX SDK MxRec: ${VERSION}" >> ./version.info - -pkg_dir=mindxsdk-mxrec -rm -rf "${pkg_dir}" -mkdir "${pkg_dir}" -mv version.info "${pkg_dir}" - # 配置MxRec C++代码路径和AccCTR路径 src_path="${MxRec_DIR}"/src acc_ctr_path="${MxRec_DIR}"/src/AccCTR @@ -134,19 +105,10 @@ function collect_so_file() cp ${acc_ctr_path}/output/ock_ctr_common/lib/* libasc cp -df "${MxRec_DIR}"/output/*.so* libasc cp "${opensource_path}"/securec/lib/libsecurec.so libasc -} - -function gen_wheel_file() -{ cd "${MxRec_DIR}" touch "${src_path}"/libasc/__init__.py rm -rf "${MxRec_DIR}"/mx_rec/libasc mv "${src_path}"/libasc "${MxRec_DIR}"/mx_rec - python3.7 setup.py bdist_wheel --plat-name=linux_$(arch) - mkdir -p "$1" - echo "moving whl file $1" - mv dist/mx_rec*.whl "$1" - rm -rf "${MxRec_DIR}"/mx_rec/libasc } # start to build MxRec @@ -158,13 +120,4 @@ echo "---------------- compile MxRec so files ----------------" compile_so_file "${tf1_path}" echo "---------------- collect so files and mv them to libasc ----------------" collect_so_file -echo "---------------- generate MxRec wheel package ----------------" -gen_wheel_file "$SCRIPT_DIR"/"${pkg_dir}"/tf1_whl echo "---------------- compile MxRec success!!!! ----------------" - -# start to compile cust op -echo "---------------- start to compile cust op ----------------" -cd "${MxRec_DIR}"/cust_op/cust_op_by_addr -chmod u+x run.sh -./run.sh -echo "---------------- compile cust op success!!!! 
----------------" \ No newline at end of file diff --git a/build/build_tf2_with_opensource.sh b/build/build_tf2.sh similarity index 71% rename from build/build_tf2_with_opensource.sh rename to build/build_tf2.sh index 08aaf1644aa9251a0051d42f00c46e0d44cf7ad8..639024ffeb820d724151d20962c331130d35a0d6 100644 --- a/build/build_tf2_with_opensource.sh +++ b/build/build_tf2.sh @@ -15,13 +15,11 @@ # ============================================================================== ################################################################## -# build_tf2_with_opensource.sh 编译MxRec和动态扩容算子 +# build_tf2.sh 编译MxRec # 编译环境:Python3.7.5 GCC 7.3.0 CMake 3.20.6 -# 代码主要分为四部分: +# 代码主要分为两部分: # 1、准备编译MxRec所需依赖:pybind11(v2.10.3) securec # 2、编译securec、AccCTR以及MxRec -# 3、生成MxRec Wheel包,生成的whl包在当前目录下的mindxsdk-mxrec/tf2_whl -# 4、编译动态扩容算子 ################################################################## set -e @@ -64,33 +62,6 @@ source /opt/buildtools/tf2_env/bin/activate tf2_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow deactivate tf2_env -project_output_path="${MxRec_DIR}"/output/ -VERSION_FILE="${MxRec_DIR}"/../mindxsdk/build/conf/config.yaml - -function get_version() { - if [ -f "$VERSION_FILE" ]; then - VERSION=$(sed '/.*mindxsdk:/!d;s/.*: //' "$VERSION_FILE") - if [[ "$VERSION" == *.[b/B]* ]] && [[ "$VERSION" != *.[RC/rc]* ]]; then - VERSION=${VERSION%.*} - fi - else - VERSION="5.0.0" - fi -} - -rm -rf "${project_output_path}" -rm -rf "${SCRIPT_DIR}/lib" - -# 获取MxRec版本信息 -get_version -export VERSION -echo "MindX SDK MxRec: ${VERSION}" >> ./version.info - -pkg_dir=mindxsdk-mxrec -rm -rf "${pkg_dir}" -mkdir "${pkg_dir}" -mv version.info "${pkg_dir}" - # 配置MxRec C++代码路径和AccCTR路径 src_path="${MxRec_DIR}"/src acc_ctr_path="${MxRec_DIR}"/src/AccCTR @@ -134,19 +105,10 @@ function collect_so_file() cp ${acc_ctr_path}/output/ock_ctr_common/lib/* libasc cp -df "${MxRec_DIR}"/output/*.so* libasc cp "${opensource_path}"/securec/lib/libsecurec.so libasc -} - -function gen_wheel_file() -{ cd "${MxRec_DIR}" touch "${src_path}"/libasc/__init__.py rm -rf "${MxRec_DIR}"/mx_rec/libasc mv "${src_path}"/libasc "${MxRec_DIR}"/mx_rec - python3.7 setup.py bdist_wheel --plat-name=linux_$(arch) - mkdir -p "$1" - echo "moving whl file $1" - mv dist/mx_rec*.whl "$1" - rm -rf "${MxRec_DIR}"/mx_rec/libasc } # start to build MxRec @@ -158,13 +120,4 @@ echo "---------------- compile MxRec so files ----------------" compile_so_file "${tf2_path}" echo "---------------- collect so files and mv them to libasc ----------------" collect_so_file -echo "---------------- generate MxRec wheel package ----------------" -gen_wheel_file "$SCRIPT_DIR"/"${pkg_dir}"/tf2_whl echo "---------------- compile MxRec success!!!! ----------------" - -# start to compile cust op -echo "---------------- start to compile cust op ----------------" -cd "${MxRec_DIR}"/cust_op/cust_op_by_addr -chmod u+x run.sh -./run.sh -echo "---------------- compile cust op success!!!! 
----------------" \ No newline at end of file diff --git a/build/build.sh b/build/gen_mxrec_tar_pkg.sh similarity index 44% rename from build/build.sh rename to build/gen_mxrec_tar_pkg.sh index 0eb688fd4145a5111c72825c105b99cbf1a9d464..b5cba7a2e1b46ccd1ee700e38197156becf5d3f6 100644 --- a/build/build.sh +++ b/build/gen_mxrec_tar_pkg.sh @@ -18,11 +18,9 @@ set -e warn() { echo >&2 -e "\033[1;31m[WARN ][Depend ] $1\033[1;37m" ; } ARCH="$(uname -m)" SCRIPT_DIR=$(dirname "$(readlink -f "$0")") -ROOT_DIR=$(dirname "${SCRIPT_DIR}") -cd "$SCRIPT_DIR" +MxRec_DIR=$(dirname "${SCRIPT_DIR}") - -VERSION_FILE="${ROOT_DIR}"/../mindxsdk/build/conf/config.yaml +VERSION_FILE="${MxRec_DIR}"/../mindxsdk/build/conf/config.yaml get_version() { if [ -f "$VERSION_FILE" ]; then VERSION=$(sed '/.*mindxsdk:/!d;s/.*: //' "$VERSION_FILE") @@ -30,96 +28,57 @@ get_version() { VERSION=${VERSION%.*} fi else - VERSION="5.0.0" - fi -} - -remove() -{ - if [ -d "$1" ]; then - rm -rf "$1" - elif [ -f "$1" ]; then - rm -f "$1" + VERSION="6.0.RC2" fi } -project_output_path="${ROOT_DIR}"/output/ -remove "${project_output_path}" -remove "${SCRIPT_DIR}/lib" get_version -export VERSION echo "MindX SDK mxrec: ${VERSION}" >> ./version.info pkg_dir=mindxsdk-mxrec -remove "${pkg_dir}" -mkdir "${pkg_dir}" -mv version.info "${pkg_dir}" - -src_path="${ROOT_DIR}"/src -cd "${ROOT_DIR}" - release_tar=Ascend-"${pkg_dir}"_"${VERSION}"_linux-"${ARCH}".tar.gz +mv version.info "${SCRIPT_DIR}"/"${pkg_dir}" -gen_tar_file() +function gen_tar_file() { - cd "${src_path}" - cp -r "${src_path}"/../cust_op ../build/"${pkg_dir}" - cp -r "${src_path}"/../examples ../build/"${pkg_dir}" + cd "${MxRec_DIR}" + cp -r ./cust_op ./build/"${pkg_dir}" + cp -r ./examples ./build/"${pkg_dir}" # change dirs and files 's permission - chmod 550 ../build/"${pkg_dir}"/tf1_whl - chmod 550 ../build/"${pkg_dir}"/tf1_whl/mx_rec*.whl - chmod 550 ../build/"${pkg_dir}"/tf2_whl - chmod 550 ../build/"${pkg_dir}"/tf2_whl/mx_rec*.whl - chmod 550 ../build/"${pkg_dir}"/cust_op/ - chmod 550 ../build/"${pkg_dir}"/cust_op/cust_op_by_addr - cd ../build/"${pkg_dir}"/cust_op/cust_op_by_addr + chmod 550 ./build/"${pkg_dir}"/tf1_whl + chmod 550 ./build/"${pkg_dir}"/tf1_whl/mx_rec*.whl + chmod 550 ./build/"${pkg_dir}"/tf2_whl + chmod 550 ./build/"${pkg_dir}"/tf2_whl/mx_rec*.whl + chmod 550 ./build/"${pkg_dir}"/cust_op/ + chmod 550 ./build/"${pkg_dir}"/cust_op/cust_op_by_addr + cd ./build/"${pkg_dir}"/cust_op/cust_op_by_addr chmod 550 *.sh chmod 640 *.json chmod 550 op_host op_kernel op_host/* op_kernel/* cd - - cd ../build + cd ./build/"${pkg_dir}"/cust_op/ + chmod 550 -R fused_lazy_adam + chmod 640 fused_lazy_adam/*.json + cd - + cd ./build tar -zvcf "${release_tar}" "${pkg_dir}" || { warn "compression failed, packages might be broken" } - mv "${release_tar}" "${SCRIPT_DIR}"/../output/ + mv "${release_tar}" ../output/ } -clean() +function clean() { - remove "${ROOT_DIR}"/dist - remove "${ROOT_DIR}"/install - remove "${ROOT_DIR}"/mx_rec.egg-info - remove "${ROOT_DIR}"/src/build - remove "${ROOT_DIR}"/build/bdist.linux-"$(arch)" - remove "${ROOT_DIR}"/build/tf2_env - remove "${ROOT_DIR}"/build/tf1_env - remove "${ROOT_DIR}"/build/lib - remove "${ROOT_DIR}"/build/mindxsdk-mxrec + rm -rf "${MxRec_DIR}"/dist + rm -rf "${MxRec_DIR}"/mx_rec.egg-info + rm -rf "${MxRec_DIR}"/src/build + rm -rf "${MxRec_DIR}"/mx_rec/libasc + rm -rf "${MxRec_DIR}"/build/lib + rm -rf "${MxRec_DIR}"/build/bdist.linux-${ARCH} } +gen_tar_file -if [ "$(uname -m)" = "x86_64" ] -then - echo "-----Build gen tar -----" - bash 
${ROOT_DIR}/build/build_tf1_with_opensource.sh - bash ${ROOT_DIR}/build/build_tf2_with_opensource.sh - gen_tar_file - echo "-----Build gen tar finished-----" - - # clean - echo "-----Done-----" -fi - -if [ "$(uname -m)" = "aarch64" ] -then - echo "-----Build gen tar -----" - bash ${ROOT_DIR}/build/build_tf1_with_opensource.sh - bash ${ROOT_DIR}/build/build_tf2_with_opensource.sh - gen_tar_file - echo "-----Build gen tar finished-----" - - # clean - echo "-----Done-----" -fi \ No newline at end of file +clean diff --git a/build/move_whl_file_2_pkg_dir.sh b/build/move_whl_file_2_pkg_dir.sh new file mode 100644 index 0000000000000000000000000000000000000000..d489c2fb8e0bc1a836e69af38acf9f1e39b530b1 --- /dev/null +++ b/build/move_whl_file_2_pkg_dir.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -e +warn() { echo >&2 -e "\033[1;31m[WARN ][Depend ] $1\033[1;37m" ; } +ARCH="$(uname -m)" +SCRIPT_DIR=$(dirname "$(readlink -f "$0")") +MxRec_DIR=$(dirname "${SCRIPT_DIR}") +pkg_dir=mindxsdk-mxrec +tf_version=$1 + +function move_whl_file_2_pkg_dir() { + mkdir -p "$SCRIPT_DIR"/"${pkg_dir}"/"${tf_version}"_whl + rm -rf "$SCRIPT_DIR"/"${pkg_dir}"/"${tf_version}"_whl/* + mv ${MxRec_DIR}/dist/mx_rec*.whl "$SCRIPT_DIR"/"${pkg_dir}"/"${tf_version}"_whl + cd "$SCRIPT_DIR"/"${pkg_dir}"/"${tf_version}"_whl + whl_file=$(ls .) + mv "$whl_file" "${whl_file/any/linux_${ARCH}}" + cd - +} + +move_whl_file_2_pkg_dir \ No newline at end of file diff --git a/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/inc/common.h b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/inc/common.h new file mode 100644 index 0000000000000000000000000000000000000000..954f3f33b8f10fbdee517c77d89051091cb75f3f --- /dev/null +++ b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/inc/common.h @@ -0,0 +1,45 @@ +/** +* @file common.h +* +* Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved. +* +* This program is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +*/ +#ifndef COMMON_H +#define COMMON_H + +#include +#include +#include +#include +#include + +#include "acl/acl.h" + +#define SUCCESS 0 +#define FAILED 1 + +#define INFO_LOG(fmt, args...) fprintf(stdout, "[INFO] " fmt "\n", ##args) +#define WARN_LOG(fmt, args...) fprintf(stdout, "[WARN] " fmt "\n", ##args) +#define ERROR_LOG(fmt, args...) 
fprintf(stderr, "[ERROR] " fmt "\n", ##args) + +/** + * @brief Read data from file + * @param [in] filePath: file path + * @param [out] fileSize: file size + * @return read result + */ +bool ReadFile(const std::string &filePath, size_t fileSize, void *buffer, size_t bufferSize); + +/** + * @brief Write data to file + * @param [in] filePath: file path + * @param [in] buffer: data to write to file + * @param [in] size: size to write + * @return write result + */ +bool WriteFile(const std::string &filePath, const void *buffer, size_t size); + +#endif // COMMON_H diff --git a/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/inc/op_runner.h b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/inc/op_runner.h new file mode 100644 index 0000000000000000000000000000000000000000..03d0aff4241f81418bf70f0c70d21506e8f3a695 --- /dev/null +++ b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/inc/op_runner.h @@ -0,0 +1,182 @@ +/** +* @file op_runner.h +* +* Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved. +* +* This program is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +*/ +#ifndef OP_RUNNER_H +#define OP_RUNNER_H + +#include "aclnn/acl_meta.h" +#include "acl/acl.h" +#include "common.h" +#include "operator_desc.h" + +/** + * Op Runner + */ +class OpRunner { +public: + /** + * @brief Constructor + * @param [in] opDesc: op description + */ + explicit OpRunner(OperatorDesc *opDesc); + + /** + * @brief Destructor + */ + virtual ~OpRunner(); + + /** + * @brief Init op runner + */ + bool Init(); + + /** + * @brief Get number of inputs + * @return number of inputs + */ + const size_t NumInputs(); + + /** + * @brief Get number of outputs + * @return number of outputs + */ + const size_t NumOutputs(); + + /** + * @brief Get input size by index + * @param [in] index: input index + * @return size of the input + */ + const size_t GetInputSize(size_t index) const; + const size_t GetInputNumDims(size_t index) const; + aclDataType GetInputDataType(size_t index) const; + aclFormat GetInputFormat(size_t index) const; + + /** + * @brief Get output size by index + * @param [in] index: output index + * @return size of the output + */ + size_t GetOutputSize(size_t index) const; + const size_t GetOutputNumDims(size_t index) const; + aclDataType GetOutputDataType(size_t index) const; + aclFormat GetOutputFormat(size_t index) const; + + /** + * @brief Get input element count by index + * @param i[in] ndex: input index + * @return element count of the input + */ + size_t GetInputElementCount(size_t index) const; + + /** + * @brief Get output element count by index + * @param [in] index: output index + * @return element count of the output + */ + size_t GetOutputElementCount(size_t index) const; + + /** + * @brief Get input shape by index + * @param [in] index: input index + * @return shape of the output + */ + std::vector GetInputShape(size_t index) const; + + /** + * @brief Get output shape by index + * @param [in] index: output index + * @return shape of the output + */ + std::vector GetOutputShape(size_t index) const; + + /** + * @brief Get input buffer(host memory) by index + * @tparam T: data type + * @param [in] index: input index + * @return host address of the input + */ + template + T *GetInputBuffer(size_t index) + { + if (index >= numInputs_) { + ERROR_LOG("index out of range. 
index = %zu, numInputs = %zu", index, numInputs_); + return nullptr; + } + return reinterpret_cast(hostInputs_[index]); + } + + /** + * @brief Get output buffer(host memory) by index + * @tparam T: data type + * @param [in] index: output index + * @return host address of the output + */ + template + const T *GetOutputBuffer(size_t index) + { + if (index >= numOutputs_) { + ERROR_LOG("index out of range. index = %zu, numOutputs = %zu", index, numOutputs_); + return nullptr; + } + + return reinterpret_cast(hostOutputs_[index]); + } + + /** + * @brief Print readable input by index + * @param [in] index: input index + * @param [in] elementsPerRow: number of elements per row + */ + void PrintInput(size_t index, size_t elementsPerRow = 16); + + /** + * @brief Print readable output by index + * @param [in] index: output index + * @param [in] elementsPerRow: number of elements per row + */ + void PrintOutput(size_t index, size_t elementsPerRow = 16); + + /** + * @brief Compile static op + * @return compile result + */ + bool CompileStaticOp(); + + /** + * @brief Compile dynamic op + * @return compile result + */ + bool CompileDynamicOp(); + + /** + * @brief Run op + * @return run result + */ + bool RunOp(); + +private: + size_t numInputs_; + size_t numOutputs_; + + std::vector inputBuffers_; + std::vector outputBuffers_; + + std::vector devInputs_; + std::vector devOutputs_; + + std::vector hostInputs_; + std::vector hostOutputs_; + + std::vector inputTensor_; + std::vector outputTensor_; + OperatorDesc *opDesc_; +}; + +#endif // OP_RUNNER_H diff --git a/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/inc/operator_desc.h b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/inc/operator_desc.h new file mode 100644 index 0000000000000000000000000000000000000000..da7198490f8f4cea66b8f8fe69b8f2ad9f424872 --- /dev/null +++ b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/inc/operator_desc.h @@ -0,0 +1,57 @@ +/** +* @file operator_desc.h +* +* Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved. +* +* This program is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
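+*
+* Editor's note: minimal usage sketch (added for clarity, not original text),
+* using only the declarations below:
+*
+*   int64_t dims[3] = {1024, 50, 80};
+*   OperatorDesc desc;
+*   desc.AddInputTensorDesc(ACL_FLOAT, 3, dims, ACL_FORMAT_ND)
+*       .AddOutputTensorDesc(ACL_FLOAT, 3, dims, ACL_FORMAT_ND);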
+*/ +#ifndef OPERATOR_DESC_H +#define OPERATOR_DESC_H + +#include +#include + +#include "acl/acl.h" + +/** + * Op description + */ +struct OperatorDesc { + /** + * Constructor + */ + explicit OperatorDesc(); + + /** + * Destructor + */ + virtual ~OperatorDesc(); + + /** + * Add an input tensor description + * @param [in] dataType: data type + * @param [in] numDims: number of dims + * @param [in] dims: dims + * @param [in] format: format + * @return OperatorDesc + */ + OperatorDesc &AddInputTensorDesc(aclDataType dataType, int numDims, const int64_t *dims, aclFormat format); + + /** + * Add an output tensor description + * @param [in] dataType: data type + * @param [in] numDims: number of dims + * @param [in] dims: dims + * @param [in] format: format + * @return OperatorDesc + */ + OperatorDesc &AddOutputTensorDesc(aclDataType dataType, int numDims, const int64_t *dims, aclFormat format); + + std::string opType; + std::vector inputDesc; + std::vector outputDesc; +}; + +#endif // OPERATOR_DESC_H diff --git a/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/run.sh b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/run.sh new file mode 100755 index 0000000000000000000000000000000000000000..6793de823cd9a4cadf2cd7305522b5a7f82f7151 --- /dev/null +++ b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/run.sh @@ -0,0 +1,91 @@ +#!/bin/bash +export ASCEND_SLOG_PRINT_TO_STDOUT=0 +export ASCEND_GLOBAL_LOG_LEVEL=0 + +CURRENT_DIR=$( + cd $(dirname ${BASH_SOURCE:-$0}) + pwd +) +cd $CURRENT_DIR + +SHORT=v:, +LONG=dtype:, +OPTS=$(getopt -a --options $SHORT --longoptions $LONG -- "$@") +eval set -- "$OPTS" +while : +do + case "$1" in + # float16, float, int32 + (-v | --dtype) + DTYPE="$2" + shift 2;; + (--) + shift; + break;; + (*) + echo "[ERROR] Unexpected option: $1"; + break;; + esac +done + +if [ ! $ASCEND_HOME_DIR ]; then + if [ -d "$HOME/Ascend/ascend-toolkit/latest" ]; then + export ASCEND_HOME_DIR=$HOME/Ascend/ascend-toolkit/latest + else + export ASCEND_HOME_DIR=/usr/local/Ascend/ascend-toolkit/latest + fi +fi + +export DDK_PATH=$ASCEND_HOME_DIR +arch=$(uname -m) +export NPU_HOST_LIB=$ASCEND_HOME_DIR/${arch}-linux/lib64 + +function main { + rm -rf $HOME/ascend/log/* + rm ./input/*.bin + rm ./output/*.bin + + cd $CURRENT_DIR + python3 scripts/gen_data.py + if [ $? -ne 0 ]; then + echo "ERROR: generate input data failed!" + return 1 + fi + echo "INFO: generate input data success!" + + cd $CURRENT_DIR; rm -rf build; mkdir -p build; cd build + cmake ../src + if [ $? -ne 0 ]; then + echo "ERROR: cmake failed!" + return 1 + fi + echo "INFO: cmake success!" + make + if [ $? -ne 0 ]; then + echo "ERROR: make failed!" + return 1 + fi + echo "INFO: make success!" + + cd $CURRENT_DIR/output + echo "INFO: execute op!" + ./execute_attention_fusion_grad_op + + if [ $? -ne 0 ]; then + echo "ERROR: acl executable run failed! please check your project!" + return 1 + fi + echo "INFO: acl executable run success!" + cd $CURRENT_DIR + ret=`python3 scripts/verify_result.py output/grad_query.bin output/grad_key.bin output/grad_value.bin output/golden_grad_query.bin output/golden_grad_key.bin output/golden_grad_value.bin ` + echo $ret + if [ "x$ret" == "xtest pass" ]; then + echo "" + echo "#####################################" + echo "INFO: you have passed the Precision!" 
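+    # Editor's note (added): $ret captures everything verify_result.py prints,
+    # including the sample arrays, so the exact "test pass" match above only
+    # succeeds when that string is the script's entire output; a looser check
+    # would be: echo "$ret" | grep -q "test pass"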
+ echo "#####################################" + echo "" + fi +} + +main diff --git a/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/scripts/gen_data.py b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/scripts/gen_data.py new file mode 100644 index 0000000000000000000000000000000000000000..69077ee3dd495479fbaf20dcf9b442923d77efed --- /dev/null +++ b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/scripts/gen_data.py @@ -0,0 +1,47 @@ +#!/usr/bin/python3 +# -*- coding:utf-8 -*- +# Copyright 2024 Huawei Technologies Co., Ltd +import numpy as np +import os +import math + +def softmax_grad(grad, src): + dst = grad * src + dst = np.sum(dst, axis=-1, keepdims=True) + dst = (grad - dst) * src + return dst + +def param_attn_layer_grad(dout, softmax_out, query, key, value): + # Dv and dS + dv = np.matmul(np.transpose(softmax_out, (0, 2, 1)), dout) + dS = np.matmul(dout, np.transpose(value, (0, 2, 1))) + dS = softmax_grad(dS, softmax_out)/math.sqrt(query.shape[2]) + # Atten + dQ = np.matmul(dS, key) + dK = np.matmul(np.transpose(dS, (0, 2, 1)), query) + return dQ, dK, dv + +def gen_golden_data_simple(): + + dout = np.random.uniform(-1, 1,[1024, 1000, 80]).astype(np.float32) + softmax_out = np.random.uniform(-1, 1,[1024, 1000, 50]).astype(np.float32) + query = np.random.uniform(-1, 1,[1024, 1000, 80]).astype(np.float32) + key = np.random.uniform(-1, 1,[1024, 50, 80]).astype(np.float32) + value = np.random.uniform(-1, 1,[1024, 50, 80]).astype(np.float32) + + grad_query, grad_key, grad_value = param_attn_layer_grad(dout, softmax_out, query, key, value) + + os.system("mkdir -p input") + os.system("mkdir -p output") + dout.tofile("./input/dout.bin") + softmax_out.tofile("./input/softmax_out.bin") + query.tofile("./input/query.bin") + key.tofile("./input/key.bin") + value.tofile("./input/value.bin") + + grad_query.tofile("./output/golden_grad_query.bin") + grad_key.tofile("./output/golden_grad_key.bin") + grad_value.tofile("./output/golden_grad_value.bin") + +if __name__ == "__main__": + gen_golden_data_simple() diff --git a/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/scripts/verify_result.py b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/scripts/verify_result.py new file mode 100644 index 0000000000000000000000000000000000000000..7781d41f6d12eb21c1afeaff7f904491b98b5921 --- /dev/null +++ b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/scripts/verify_result.py @@ -0,0 +1,34 @@ +#!/usr/bin/python3 +# -*- coding:utf-8 -*- +# Copyright 2024 Huawei Technologies Co., Ltd +import os +import sys +import numpy as np + +loss = 1e-3 +minimum = 10e-10 + +def verify_result(real_result, golden): + real_result = np.fromfile(real_result, dtype=np.float32) + golden = np.fromfile(golden, dtype=np.float32) + real_result = real_result[:golden.size] + print(real_result[:32]) + print(golden[:32]) + result = np.abs(real_result - golden) + deno = np.maximum(np.abs(real_result), np.abs(golden)) + result_atol = np.less_equal(result, loss) + result_rtol = np.less_equal(result / np.add(deno, minimum), loss) + if not result_rtol.all() and not result_atol.all(): + if np.sum(result_rtol == False) > real_result.size * loss and np.sum(result_atol == False) > real_result.size * loss: + print("[ERROR] result error") + return False + print("test pass") + return True + +if __name__ == '__main__': + print("=============================grad query============") + verify_result(sys.argv[1], sys.argv[4]) + print("=============================grad key============") + 
verify_result(sys.argv[2], sys.argv[5]) + print("=============================grad value============") + verify_result(sys.argv[3], sys.argv[6]) \ No newline at end of file diff --git a/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/src/CMakeLists.txt b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/src/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..f1459958d64f4a6c646dd809145c87bcccc14321 --- /dev/null +++ b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/src/CMakeLists.txt @@ -0,0 +1,68 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + +# CMake lowest version requirement +cmake_minimum_required(VERSION 3.5.1) + +# project information +project(acl_execute_attention_fusion_grad) + +# Compile options +add_compile_options(-std=c++11) + +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "../output") +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "../output") + +set(INC_PATH $ENV{DDK_PATH}) + +if (NOT DEFINED ENV{DDK_PATH}) + set(INC_PATH "/usr/local/Ascend/ascend-toolkit/latest") + message(STATUS "set default INC_PATH: ${INC_PATH}") +else () + message(STATUS "env INC_PATH: ${INC_PATH}") +endif() + +set(CUST_PKG_PATH "${INC_PATH}/opp/vendors/attention_fusion_grad/op_api") + +set(LIB_PATH $ENV{NPU_HOST_LIB}) + +# Dynamic libraries in the stub directory can only be used for compilation +if (NOT DEFINED ENV{NPU_HOST_LIB}) + set(LIB_PATH "/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64/stub/") + set(LIB_PATH1 "/usr/local/Ascend/ascend-toolkit/latest/atc/lib64/stub/") + message(STATUS "set default LIB_PATH: ${LIB_PATH}") +else () + message(STATUS "env LIB_PATH: ${LIB_PATH}") +endif() + +# Header path +include_directories( + ${INC_PATH}/runtime/include + ${INC_PATH}/atc/include + ../inc + ${CUST_PKG_PATH}/include +) + +# add host lib path +link_directories( + ${LIB_PATH} + ${LIB_PATH1} + ${CUST_PKG_PATH}/lib +) + +add_executable(execute_attention_fusion_grad_op + operator_desc.cpp + op_runner.cpp + main.cpp + op_runner.cpp + common.cpp +) + +target_link_libraries(execute_attention_fusion_grad_op + ascendcl + cust_opapi + acl_op_compiler + nnopbase + stdc++ +) + +install(TARGETS execute_attention_fusion_grad_op DESTINATION ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) diff --git a/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/src/common.cpp b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/src/common.cpp new file mode 100644 index 0000000000000000000000000000000000000000..02eac9b41eee6065811c862c5aab773172e74b5e --- /dev/null +++ b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/src/common.cpp @@ -0,0 +1,79 @@ +/** +* @file common.cpp +* +* Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved. +* +* This program is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +*/ +#include "common.h" + +#include +#include +#include +#include + +extern bool g_isDevice; + +bool ReadFile(const std::string &filePath, size_t fileSize, void *buffer, size_t bufferSize) +{ + struct stat sBuf; + int fileStatus = stat(filePath.data(), &sBuf); + if (fileStatus == -1) { + ERROR_LOG("failed to get file %s", filePath.c_str()); + return false; + } + if (S_ISREG(sBuf.st_mode) == 0) { + ERROR_LOG("%s is not a file, please enter a file", filePath.c_str()); + return false; + } + + std::ifstream file; + file.open(filePath, std::ios::binary); + if (!file.is_open()) { + ERROR_LOG("Open file failed. 
path = %s", filePath.c_str()); + return false; + } + + std::filebuf *buf = file.rdbuf(); + size_t size = buf->pubseekoff(0, std::ios::end, std::ios::in); + if (size == 0) { + ERROR_LOG("file size is 0"); + file.close(); + return false; + } + if (size > bufferSize) { + ERROR_LOG("file size is larger than buffer size"); + file.close(); + return false; + } + buf->pubseekpos(0, std::ios::in); + buf->sgetn(static_cast(buffer), size); + fileSize = size; + file.close(); + return true; +} + +bool WriteFile(const std::string &filePath, const void *buffer, size_t size) +{ + if (buffer == nullptr) { + ERROR_LOG("Write file failed. buffer is nullptr"); + return false; + } + + int fd = open(filePath.c_str(), O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWRITE); + if (fd < 0) { + ERROR_LOG("Open file failed. path = %s", filePath.c_str()); + return false; + } + + auto writeSize = write(fd, buffer, size); + (void) close(fd); + if (writeSize != size) { + ERROR_LOG("Write file Failed."); + return false; + } + + return true; +} diff --git a/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/src/main.cpp b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/src/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e6aa83405c7c33423e5bebff444bf76025afedfa --- /dev/null +++ b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/src/main.cpp @@ -0,0 +1,182 @@ +/** +* @file main.cpp +* +* Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved. +* +* This program is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +*/ +#include +#include +#include +#include +#include + +#include "acl/acl.h" +#include "op_runner.h" + +#include "common.h" + +bool g_isDevice = false; +int deviceId = 15; + +OperatorDesc CreateOpDesc() +{ + // define operator + std::vector dout { 1024, 1000, 80 }; + std::vector softmax_out { 1024, 1000, 50 }; + std::vector query { 1024, 1000, 80}; + std::vector key { 1024, 50, 80 }; + std::vector value { 1024, 50, 80 }; + + std::vector grad_query { 1024, 1000, 80}; + std::vector grad_key { 1024, 50, 80 }; + std::vector grad_value { 1024, 50, 80 }; + + aclFormat format = ACL_FORMAT_ND; + OperatorDesc opDesc; + opDesc.AddInputTensorDesc(ACL_FLOAT, dout.size(), dout.data(), format); + opDesc.AddInputTensorDesc(ACL_FLOAT, softmax_out.size(), softmax_out.data(), format); + opDesc.AddInputTensorDesc(ACL_FLOAT, query.size(), query.data(), format); + opDesc.AddInputTensorDesc(ACL_FLOAT, key.size(), key.data(), format); + opDesc.AddInputTensorDesc(ACL_FLOAT, value.size(), value.data(), format); + + opDesc.AddOutputTensorDesc(ACL_FLOAT, grad_query.size(), grad_query.data(), format); + opDesc.AddOutputTensorDesc(ACL_FLOAT, grad_key.size(), grad_key.data(), format); + opDesc.AddOutputTensorDesc(ACL_FLOAT, grad_value.size(), grad_value.data(), format); + return opDesc; +} + +bool SetInputData(OpRunner &runner) +{ + size_t fileSize = 0; + ReadFile("../input/dout.bin", fileSize, runner.GetInputBuffer(0), runner.GetInputSize(0)); + ReadFile("../input/softmax_out.bin", fileSize, runner.GetInputBuffer(1), runner.GetInputSize(1)); + ReadFile("../input/query.bin", fileSize, runner.GetInputBuffer(2), runner.GetInputSize(2)); + ReadFile("../input/key.bin", fileSize, runner.GetInputBuffer(3), runner.GetInputSize(3)); + ReadFile("../input/value.bin", fileSize, runner.GetInputBuffer(4), runner.GetInputSize(4)); + INFO_LOG("Set input success"); + return 
true; +} + +bool ProcessOutputData(OpRunner &runner) +{ + WriteFile("../output/grad_query.bin", runner.GetOutputBuffer(0), runner.GetOutputSize(0)); + WriteFile("../output/grad_key.bin", runner.GetOutputBuffer(1), runner.GetOutputSize(1)); + WriteFile("../output/grad_value.bin", runner.GetOutputBuffer(2), runner.GetOutputSize(2)); + INFO_LOG("Write output success"); + return true; +} + +void DestoryResource() +{ + bool flag = false; + if (aclrtResetDevice(deviceId) != ACL_SUCCESS) { + ERROR_LOG("Reset device %d failed", deviceId); + flag = true; + } + INFO_LOG("Reset Device success"); + if (aclFinalize() != ACL_SUCCESS) { + ERROR_LOG("Finalize acl failed"); + flag = true; + } + if (flag) { + ERROR_LOG("Destory resource failed"); + } else { + INFO_LOG("Destory resource success"); + } +} + +bool InitResource() +{ + std::string output = "../output"; + if (access(output.c_str(), 0) == -1) { + int ret = mkdir(output.c_str(), 0700); + if (ret == 0) { + INFO_LOG("Make output directory successfully"); + } + else { + ERROR_LOG("Make output directory fail"); + return false; + } + } + + // acl.json is dump or profiling config file + if (aclInit(NULL) != ACL_SUCCESS) { + ERROR_LOG("acl init failed"); + return false; + } + + if (aclrtSetDevice(deviceId) != ACL_SUCCESS) { + ERROR_LOG("Set device failed. deviceId is %d", deviceId); + (void)aclFinalize(); + return false; + } + INFO_LOG("Set device[%d] success", deviceId); + + // runMode is ACL_HOST which represents app is running in host + // runMode is ACL_DEVICE which represents app is running in device + aclrtRunMode runMode; + if (aclrtGetRunMode(&runMode) != ACL_SUCCESS) { + ERROR_LOG("Get run mode failed"); + DestoryResource(); + return false; + } + g_isDevice = (runMode == ACL_DEVICE); + INFO_LOG("Get RunMode[%d] success", runMode); + + return true; +} + +bool RunOp() +{ + // create op desc + OperatorDesc opDesc = CreateOpDesc(); + + // create Runner + OpRunner opRunner(&opDesc); + if (!opRunner.Init()) { + ERROR_LOG("Init OpRunner failed"); + return false; + } + + // Load inputs + if (!SetInputData(opRunner)) { + ERROR_LOG("Set input data failed"); + return false; + } + + // Run op + if (!opRunner.RunOp()) { + ERROR_LOG("Run op failed"); + return false; + } + + // process output data + if (!ProcessOutputData(opRunner)) { + ERROR_LOG("Process output data failed"); + return false; + } + + INFO_LOG("Run op success"); + return true; +} + +int main(int argc, char **argv) +{ + if (!InitResource()) { + ERROR_LOG("Init resource failed"); + return FAILED; + } + INFO_LOG("Init resource success"); + + if (!RunOp()) { + DestoryResource(); + return FAILED; + } + + DestoryResource(); + + return SUCCESS; +} diff --git a/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/src/op_runner.cpp b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/src/op_runner.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4df5eea58b98d90125cbb4d82fe6cc5889c66c20 --- /dev/null +++ b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/src/op_runner.cpp @@ -0,0 +1,464 @@ +/** +* @file op_runner.cpp +* +* Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved. +* +* This program is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
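+*
+* Editor's note: rough lifecycle as driven from main.cpp in this patch
+* (summary added for readability, not original text):
+*
+*   OperatorDesc desc = CreateOpDesc(); // shapes/dtypes of 5 inputs, 3 outputs
+*   OpRunner runner(&desc);
+*   runner.Init();  // allocate device + host buffers, create aclTensors
+*   runner.RunOp(); // H2D copies, aclnn workspace query, launch, sync, D2H copies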
+*/ +#include "op_runner.h" +#include "aclnn_attention_fusion_grad.h" +#include +#include +#include +#include "acl/acl_op_compiler.h" +#include "common.h" + +using namespace std; + +extern bool g_isDevice; + +OpRunner::OpRunner(OperatorDesc *opDesc) : opDesc_(opDesc) +{ + numInputs_ = opDesc->inputDesc.size(); + numOutputs_ = opDesc->outputDesc.size(); +} + +OpRunner::~OpRunner() +{ + for (size_t i = 0; i < numInputs_; ++i) { + (void)aclDestroyTensor(inputTensor_[i]); + (void)aclDestroyDataBuffer(inputBuffers_[i]); + (void)aclrtFree(devInputs_[i]); + if (g_isDevice) { + (void)aclrtFree(hostInputs_[i]); + } else { + (void)aclrtFreeHost(hostInputs_[i]); + } + } + + for (size_t i = 0; i < numOutputs_; ++i) { + (void)aclDestroyTensor(outputTensor_[i]); + (void)aclDestroyDataBuffer(outputBuffers_[i]); + (void)aclrtFree(devOutputs_[i]); + if (g_isDevice) { + (void)aclrtFree(hostOutputs_[i]); + } else { + (void)aclrtFreeHost(hostOutputs_[i]); + } + } +} + +bool OpRunner::Init() +{ + for (size_t i = 0; i < numInputs_; ++i) { + auto size = GetInputSize(i); + void *devMem = nullptr; + if (aclrtMalloc(&devMem, size, ACL_MEM_MALLOC_HUGE_FIRST) != ACL_SUCCESS) { + ERROR_LOG("Malloc device memory for input[%zu] failed", i); + return false; + } + devInputs_.emplace_back(devMem); + inputBuffers_.emplace_back(aclCreateDataBuffer(devMem, size)); + + void *hostInput = nullptr; + if (g_isDevice) { + if (aclrtMalloc(&hostInput, size, ACL_MEM_MALLOC_HUGE_FIRST) != ACL_SUCCESS) { + ERROR_LOG("Malloc device memory for input[%zu] failed", i); + return false; + } + } else { + if (aclrtMallocHost(&hostInput, size) != ACL_SUCCESS) { + ERROR_LOG("Malloc device memory for input[%zu] failed", i); + return false; + } + } + if (hostInput == nullptr) { + ERROR_LOG("Malloc memory for input[%zu] failed", i); + return false; + } + hostInputs_.emplace_back(hostInput); + + aclTensor *inputTensor = aclCreateTensor(GetInputShape(i).data(), GetInputNumDims(i), GetInputDataType(i), + nullptr, 0, GetInputFormat(i), GetInputShape(i).data(), GetInputNumDims(i), devInputs_[i]); + if (inputTensor == nullptr) { + ERROR_LOG("Create Tensor for input[%zu] failed", i); + return false; + } + inputTensor_.emplace_back(inputTensor); + } + + for (size_t i = 0; i < numOutputs_; ++i) { + auto size = GetOutputSize(i); + void *devMem = nullptr; + if (aclrtMalloc(&devMem, size, ACL_MEM_MALLOC_HUGE_FIRST) != ACL_SUCCESS) { + ERROR_LOG("Malloc device memory for output[%zu] failed", i); + return false; + } + devOutputs_.emplace_back(devMem); + outputBuffers_.emplace_back(aclCreateDataBuffer(devMem, size)); + + void *hostOutput = nullptr; + if (g_isDevice) { + if (aclrtMalloc(&hostOutput, size, ACL_MEM_MALLOC_HUGE_FIRST) != ACL_SUCCESS) { + ERROR_LOG("Malloc device memory for output[%zu] failed", i); + return false; + } + } else { + if (aclrtMallocHost(&hostOutput, size) != ACL_SUCCESS) { + ERROR_LOG("Malloc device memory for output[%zu] failed", i); + return false; + } + } + if (hostOutput == nullptr) { + ERROR_LOG("Malloc host memory for output[%zu] failed", i); + return false; + } + hostOutputs_.emplace_back(hostOutput); + + aclTensor *outputTensor = aclCreateTensor(GetOutputShape(i).data(), GetOutputNumDims(i), GetOutputDataType(i), + nullptr, 0, GetOutputFormat(i), GetOutputShape(i).data(), GetOutputNumDims(i), devOutputs_[i]); + if (outputTensor == nullptr) { + ERROR_LOG("Create Tensor for output[%zu] failed", i); + return false; + } + outputTensor_.emplace_back(outputTensor); + } + + return true; +} + +const size_t OpRunner::NumInputs() +{ + 
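+    // Editor's note (added): the top-level const on these by-value accessors
+    // (const size_t NumInputs(), etc.) does not bind callers and is commonly
+    // flagged by -Wignored-qualifiers; a plain size_t return is equivalent.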
return numInputs_; +} + +const size_t OpRunner::NumOutputs() +{ + return numOutputs_; +} + +const size_t OpRunner::GetInputSize(size_t index) const +{ + if (index >= numInputs_) { + ERROR_LOG("index out of range. index = %zu, numInputs = %zu", index, numInputs_); + return 0; + } + + return aclGetTensorDescSize(opDesc_->inputDesc[index]); +} + +const size_t OpRunner::GetInputNumDims(size_t index) const +{ + if (index >= numInputs_) { + ERROR_LOG("index out of range. index = %zu, numInputs = %zu", index, numInputs_); + return 0; + } + + return aclGetTensorDescNumDims(opDesc_->inputDesc[index]); +} + +aclDataType OpRunner::GetInputDataType(size_t index) const +{ + if (index >= numInputs_) { + ERROR_LOG("index out of range. index = %zu, numInputs = %zu", index, numInputs_); + return ACL_DT_UNDEFINED; + } + + return aclGetTensorDescType(opDesc_->inputDesc[index]); +} + +aclFormat OpRunner::GetInputFormat(size_t index) const +{ + if (index >= numInputs_) { + ERROR_LOG("index out of range. index = %zu, numInputs = %zu", index, numInputs_); + return ACL_FORMAT_UNDEFINED; + } + + return aclGetTensorDescFormat(opDesc_->inputDesc[index]); +} + +std::vector OpRunner::GetInputShape(size_t index) const +{ + std::vector ret; + if (index >= numInputs_) { + ERROR_LOG("index out of range. index = %zu, numInputs = %zu", index, numInputs_); + return ret; + } + + auto desc = opDesc_->inputDesc[index]; + for (size_t i = 0; i < aclGetTensorDescNumDims(desc); ++i) { + int64_t dimSize; + if (aclGetTensorDescDimV2(desc, i, &dimSize) != ACL_SUCCESS) { + ERROR_LOG("get dims from tensor desc failed. dims index = %zu", i); + ret.clear(); + return ret; + } + ret.emplace_back(dimSize); + } + + return ret; +} + +size_t OpRunner::GetOutputSize(size_t index) const +{ + if (index >= numOutputs_) { + ERROR_LOG("index out of range. index = %zu, numOutputs = %zu", index, numOutputs_); + return 0; + } + + return aclGetTensorDescSize(opDesc_->outputDesc[index]); +} + +const size_t OpRunner::GetOutputNumDims(size_t index) const +{ + if (index >= numOutputs_) { + ERROR_LOG("index out of range. index = %zu, numOutputs = %zu", index, numOutputs_); + return 0; + } + + return aclGetTensorDescNumDims(opDesc_->outputDesc[index]); +} + +aclDataType OpRunner::GetOutputDataType(size_t index) const +{ + if (index >= numOutputs_) { + ERROR_LOG("index out of range. index = %zu, numOutputs = %zu", index, numOutputs_); + return ACL_DT_UNDEFINED; + } + + return aclGetTensorDescType(opDesc_->outputDesc[index]); +} + + +aclFormat OpRunner::GetOutputFormat(size_t index) const +{ + if (index >= numOutputs_) { + ERROR_LOG("index out of range. index = %zu, numOutputs = %zu", index, numOutputs_); + return ACL_FORMAT_UNDEFINED; + } + + return aclGetTensorDescFormat(opDesc_->outputDesc[index]); +} + +std::vector OpRunner::GetOutputShape(size_t index) const +{ + std::vector ret; + if (index >= numOutputs_) { + ERROR_LOG("index out of range. index = %zu, numOutputs = %zu", index, numOutputs_); + return ret; + } + + auto desc = opDesc_->outputDesc[index]; + for (size_t i = 0; i < aclGetTensorDescNumDims(desc); ++i) { + int64_t dimSize; + if (aclGetTensorDescDimV2(desc, i, &dimSize) != ACL_SUCCESS) { + ERROR_LOG("get dims from tensor desc failed. dims index = %zu", i); + ret.clear(); + return ret; + } + ret.emplace_back(dimSize); + } + return ret; +} + +size_t OpRunner::GetInputElementCount(size_t index) const +{ + if (index >= opDesc_->inputDesc.size()) { + ERROR_LOG("index out of range. 
index = %zu, numInputs = %zu", index, numInputs_); + return 0; + } + + return aclGetTensorDescElementCount(opDesc_->inputDesc[index]); +} + +size_t OpRunner::GetOutputElementCount(size_t index) const +{ + if (index >= opDesc_->outputDesc.size()) { + ERROR_LOG("index out of range. index = %zu, numOutputs = %zu", index, numOutputs_); + return 0; + } + + return aclGetTensorDescElementCount(opDesc_->outputDesc[index]); +} + +bool OpRunner::RunOp() +{ + for (size_t i = 0; i < numInputs_; ++i) { + auto size = GetInputSize(i); + aclrtMemcpyKind kind = ACL_MEMCPY_HOST_TO_DEVICE; + if (g_isDevice) { + kind = ACL_MEMCPY_DEVICE_TO_DEVICE; + } + if (aclrtMemcpy(devInputs_[i], size, hostInputs_[i], size, kind) != ACL_SUCCESS) { + ERROR_LOG("Copy input[%zu] failed", i); + return false; + } + INFO_LOG("Copy input[%zu] success", i); + } + + aclrtStream stream = nullptr; + if (aclrtCreateStream(&stream) != ACL_SUCCESS) { + ERROR_LOG("Create stream failed"); + return false; + } + INFO_LOG("Create stream success"); + + size_t workspaceSize = 0; + aclOpExecutor *handle = nullptr; + auto ret = aclnnAttentionFusionGradGetWorkspaceSize(inputTensor_[0], inputTensor_[1], inputTensor_[2], inputTensor_[3], inputTensor_[4], outputTensor_[0], outputTensor_[1], outputTensor_[2], + &workspaceSize, &handle); + if (ret != ACL_SUCCESS) { + (void)aclrtDestroyStream(stream); + ERROR_LOG("Get Operator Workspace failed. error code is %d", static_cast(ret)); + return false; + } + INFO_LOG("Execute aclnnAttentionFusionGradGetWorkspaceSize success, workspace size %lu", workspaceSize); + + void *workspace = nullptr; + if (workspaceSize != 0) { + if (aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST) != ACL_SUCCESS) { + ERROR_LOG("Malloc device memory failed"); + } + } + + ret = aclnnAttentionFusionGrad(workspace, workspaceSize, handle, stream); + if (ret != ACL_SUCCESS) { + (void)aclrtDestroyStream(stream); + ERROR_LOG("Execute Operator failed. error code is %d", static_cast(ret)); + return false; + } + INFO_LOG("Execute aclnnAttentionFusionGrad success"); + + ret = aclrtSynchronizeStreamWithTimeout(stream, 5000); + if (ret != SUCCESS) { + ERROR_LOG("Synchronize stream failed. 
error code is %d", static_cast(ret)); + (void)aclrtDestroyStream(stream); + return false; + } + INFO_LOG("Synchronize stream success"); + + auto beforeTime = std::chrono::steady_clock::now(); + for (int i = 0; i<100; i++) { + ret = aclnnAttentionFusionGradGetWorkspaceSize(inputTensor_[0], inputTensor_[1], inputTensor_[2], inputTensor_[3], inputTensor_[4], outputTensor_[0], outputTensor_[1], outputTensor_[2], + &workspaceSize, &handle); + ret = aclnnAttentionFusionGrad(workspace, workspaceSize, handle, stream); + } + ret = aclrtSynchronizeStreamWithTimeout(stream, 5000); + auto afterTime = std::chrono::steady_clock::now(); + double duration_microsecond = std::chrono::duration(afterTime - beforeTime).count(); + std::cout << "time cost " << duration_microsecond/100 << " us" << std::endl; + + for (size_t i = 0; i < numOutputs_; ++i) { + auto size = GetOutputSize(i); + aclrtMemcpyKind kind = ACL_MEMCPY_DEVICE_TO_HOST; + if (g_isDevice) { + kind = ACL_MEMCPY_DEVICE_TO_DEVICE; + } + if (aclrtMemcpy(hostOutputs_[i], size, devOutputs_[i], size, kind) != ACL_SUCCESS) { + INFO_LOG("Copy output[%zu] success", i); + (void)aclrtDestroyStream(stream); + return false; + } + INFO_LOG("Copy output[%zu] success", i); + } + + (void)aclrtDestroyStream(stream); + return true; +} + + +template +void DoPrintData(const T *data, size_t count, size_t elementsPerRow) +{ + assert(elementsPerRow != 0); + for (size_t i = 0; i < count; ++i) { + std::cout << std::setw(10) << data[i]; + if (i % elementsPerRow == elementsPerRow - 1) { + std::cout << std::endl; + } + } +} + +void DoPrintFp16Data(const aclFloat16 *data, size_t count, size_t elementsPerRow) +{ + assert(elementsPerRow != 0); + for (size_t i = 0; i < count; ++i) { + std::cout << std::setw(10) << std::setprecision(4) << aclFloat16ToFloat(data[i]); + if (i % elementsPerRow == elementsPerRow - 1) { + std::cout << std::endl; + } + } +} + +void PrintData(const void *data, size_t count, aclDataType dataType, size_t elementsPerRow) +{ + if (data == nullptr) { + ERROR_LOG("Print data failed. data is nullptr"); + return; + } + + switch (dataType) { + case ACL_BOOL: + DoPrintData(reinterpret_cast(data), count, elementsPerRow); + break; + case ACL_INT8: + DoPrintData(reinterpret_cast(data), count, elementsPerRow); + break; + case ACL_UINT8: + DoPrintData(reinterpret_cast(data), count, elementsPerRow); + break; + case ACL_INT16: + DoPrintData(reinterpret_cast(data), count, elementsPerRow); + break; + case ACL_UINT16: + DoPrintData(reinterpret_cast(data), count, elementsPerRow); + break; + case ACL_INT32: + DoPrintData(reinterpret_cast(data), count, elementsPerRow); + break; + case ACL_UINT32: + DoPrintData(reinterpret_cast(data), count, elementsPerRow); + break; + case ACL_INT64: + DoPrintData(reinterpret_cast(data), count, elementsPerRow); + break; + case ACL_UINT64: + DoPrintData(reinterpret_cast(data), count, elementsPerRow); + break; + case ACL_FLOAT16: + DoPrintFp16Data(reinterpret_cast(data), count, elementsPerRow); + break; + case ACL_FLOAT: + DoPrintData(reinterpret_cast(data), count, elementsPerRow); + break; + case ACL_DOUBLE: + DoPrintData(reinterpret_cast(data), count, elementsPerRow); + break; + default: + ERROR_LOG("Unsupported type: %d", dataType); + } +} + +void OpRunner::PrintInput(size_t index, size_t numElementsPerRow) +{ + if (index >= numInputs_) { + ERROR_LOG("index out of range. 
index = %zu, numOutputs = %zu", index, numInputs_); + return; + } + + auto desc = opDesc_->inputDesc[index]; + PrintData(hostInputs_[index], GetInputElementCount(index), aclGetTensorDescType(desc), numElementsPerRow); +} + +void OpRunner::PrintOutput(size_t index, size_t numElementsPerRow) +{ + if (index >= numOutputs_) { + ERROR_LOG("index out of range. index = %zu, numOutputs = %zu", index, numOutputs_); + return; + } + + auto desc = opDesc_->outputDesc[index]; + PrintData(hostOutputs_[index], GetOutputElementCount(index), aclGetTensorDescType(desc), numElementsPerRow); +} diff --git a/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/src/operator_desc.cpp b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/src/operator_desc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1928103cf18c6a42ecd773dbe28f79553b4bcda2 --- /dev/null +++ b/cust_op/attention_fusion_grad/aclnn_attention_fusion_grad/src/operator_desc.cpp @@ -0,0 +1,56 @@ +/** +* @file operator_desc.cpp +* +* Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved. +* +* This program is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +*/ +#include "common.h" +#include "operator_desc.h" + +using namespace std; + +OperatorDesc::OperatorDesc() {} + +OperatorDesc::~OperatorDesc() +{ + for (auto *desc : inputDesc) { + aclDestroyTensorDesc(desc); + } + + for (auto *desc : outputDesc) { + aclDestroyTensorDesc(desc); + } + +} + +OperatorDesc &OperatorDesc::AddInputTensorDesc(aclDataType dataType, + int numDims, + const int64_t *dims, + aclFormat format) +{ + aclTensorDesc *desc = aclCreateTensorDesc(dataType, numDims, dims, format); + if (desc == nullptr) { + ERROR_LOG("create tensor failed"); + return *this; + } + inputDesc.emplace_back(desc); + return *this; +} + +OperatorDesc &OperatorDesc::AddOutputTensorDesc(aclDataType dataType, + int numDims, + const int64_t *dims, + aclFormat format) +{ + aclTensorDesc *desc = aclCreateTensorDesc(dataType, numDims, dims, format); + if (desc == nullptr) { + ERROR_LOG("create tensor failed"); + return *this; + } + + outputDesc.emplace_back(desc); + return *this; +} diff --git a/cust_op/cust_op_by_addr/aclnn_op_test/inc/op_runner.h b/cust_op/cust_op_by_addr/aclnn_op_test/inc/op_runner.h index bf923d7ec4cdd199954289a27cb734421fd46c26..e41e35969146bf5a0de60d57a298bc8e167a5a72 100644 --- a/cust_op/cust_op_by_addr/aclnn_op_test/inc/op_runner.h +++ b/cust_op/cust_op_by_addr/aclnn_op_test/inc/op_runner.h @@ -140,16 +140,16 @@ public: /** * @brief Print readable input by index * @param [in] index: input index - * @param [in] elementsPerRow: number of elements per row + * @param [in] numElementsPerRow: number of elements per row */ - void PrintInput(size_t index, size_t elementsPerRow = 16); + void PrintInput(size_t index, size_t numElementsPerRow = 16); /** * @brief Print readable output by index * @param [in] index: output index - * @param [in] elementsPerRow: number of elements per row + * @param [in] numElementsPerRow: number of elements per row */ - void PrintOutput(size_t index, size_t elementsPerRow = 16); + void PrintOutput(size_t index, size_t numElementsPerRow = 16); /** * @brief Compile static op diff --git a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp index 4568177386f449189af207aa0d9921d46b63f323..722914d3d7a424830fafe9619c801b364b2414d7 
100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp @@ -16,13 +16,19 @@ See the License for the specific language governing permissions and #include "embedding_lookup_by_address_tiling.h" #include "register/op_def_registry.h" +namespace { + constexpr int32_t EMBEDDING_TYPE_FLOAT16 = 2; + constexpr int32_t EMBEDDING_TYPE_INT32 = 0; + constexpr int32_t EMBEDDING_TYPE_FLOAT32 = 1; +} + namespace optiling { constexpr int32_t BLOCK_DIM = 48; // 910b一张卡48个vector核 constexpr int32_t SIZE_OF_HALF = 2; constexpr int32_t SIZE_OF_FLOAT_OR_INT = 4; constexpr int32_t MIN_BLOCK_SIZE = 32; // ub空间的数据都要按照32对齐 - constexpr int32_t UB_LIMIT = 175 * 1024; + constexpr uint32_t UB_LIMIT = 175 * 1024; constexpr int32_t USR_SIZE = 256; constexpr int32_t SYS_WORKSPACE_SIZE = 16 * 1024 * 1024; constexpr int32_t PING_PONG_NUM = 1; @@ -81,7 +87,7 @@ namespace optiling int32_t inputShape = inputTensor->GetShapeSize(); int32_t typeSize = SIZE_OF_FLOAT_OR_INT; - if (embeddingType == 2) { + if (embeddingType == EMBEDDING_TYPE_FLOAT16) { typeSize = SIZE_OF_HALF; } // shape需要对齐到的最小单位, MIN_BLOCK_SIZE=32 @@ -92,7 +98,8 @@ namespace optiling int32_t occupyAddressBytesNum = sizeof(int64_t) + typeSize * embeddingDimAligned * PING_PONG_NUM * 2; // 一轮计算中最多计算多少个addr,由于地址也要搬到ub,所以需要对齐32, - int32_t addrPerLoop = (UB_LIMIT / occupyAddressBytesNum) & (~3); // & (~3),保证地址数是4的倍数 + int32_t addrPerLoop = static_cast((UB_LIMIT / + static_cast(occupyAddressBytesNum)) & (~3u)); // & (~3u),保证地址数是4的倍数 if (addrPerLoop <= 0) { return ge::GRAPH_FAILED; } @@ -116,6 +123,7 @@ namespace optiling namespace ge { + constexpr int OUTPUT_DIMENSION = 2; static ge::graphStatus InferShape1(gert::InferShapeContext *context) { @@ -140,8 +148,12 @@ namespace ge int64_t updateDim = *attr0Value; - int64_t inputShape = context->GetInputTensor(0)->GetShapeSize(); - yShape->SetDimNum(2); + auto *inputTensor2 = context->GetInputTensor(0); + if (optiling::CheckNullPointer(inputTensor2, "inputTensor2") != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + int64_t inputShape = inputTensor2->GetShapeSize(); + yShape->SetDimNum(OUTPUT_DIMENSION); yShape->SetDim(0, inputShape); yShape->SetDim(1, updateDim); return GRAPH_SUCCESS; @@ -165,15 +177,15 @@ namespace ge } embbedingType = *attr1Value; - if (embbedingType == 0) + if (embbedingType == EMBEDDING_TYPE_INT32) { context->SetOutputDataType(0, ge::DataType(DT_INT32)); } - else if (embbedingType == 1) + else if (embbedingType == EMBEDDING_TYPE_FLOAT32) { context->SetOutputDataType(0, ge::DataType(DT_FLOAT)); } - else if (embbedingType == 2) + else if (embbedingType == EMBEDDING_TYPE_FLOAT16) { context->SetOutputDataType(0, ge::DataType(DT_FLOAT16)); diff --git a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp index 5c45e2ab1ee9040abda96c381a64b68f1f3bdbb6..43d7a886ce036c18c0f9d8da554d3f18cbbf97a9 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp @@ -16,13 +16,19 @@ See the License for the specific language governing permissions and #include "embedding_update_by_address_tiling.h" #include "register/op_def_registry.h" +namespace { + constexpr int32_t EMBEDDING_TYPE_FLOAT16 = 2; + constexpr int32_t EMBEDDING_TYPE_INT32 = 0; + constexpr int32_t EMBEDDING_TYPE_FLOAT32 = 1; +} + namespace optiling { constexpr int32_t BLOCK_DIM = 48; // 910b一张卡48个vector核 
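+    // Illustrative sketch only (not called by the op): how the UB budget below
+    // (UB_LIMIT = 175KB) bounds the number of addresses processed per loop.
+    // Each address costs 8 bytes for the int64_t pointer plus one 32B-aligned
+    // embedding row in every ping-pong buffer of the in/out queues (factor 2);
+    // the count is then rounded down to a multiple of 4 so the address list
+    // itself stays 32-byte aligned in UB.
+    static inline int64_t AddrPerLoopSketch(int32_t dim, int32_t typeSize)
+    {
+        int32_t alignNum = 32 / typeSize;                        // elements per 32-byte block
+        int32_t dimAligned = (dim + alignNum - 1) / alignNum * alignNum;
+        uint64_t bytesPerAddr = sizeof(int64_t) +
+            static_cast<uint64_t>(typeSize) * dimAligned * 2;    // x2: ping-pong in/out copies
+        return static_cast<int64_t>((175 * 1024 / bytesPerAddr) & ~3ULL);
+    }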
constexpr int32_t SIZE_OF_HALF = 2; constexpr int32_t SIZE_OF_FLOAT_OR_INT = 4; constexpr int32_t MIN_BLOCK_SIZE = 32; // ub空间的数据都要按照32对齐 - constexpr int32_t UB_LIMIT = 175 * 1024; + constexpr uint64_t UB_LIMIT = 175 * 1024; constexpr int32_t USR_SIZE = 256; constexpr int32_t SYS_WORKSPACE_SIZE = 16 * 1024 * 1024; constexpr int32_t PING_PONG_NUM = 1; @@ -38,7 +44,7 @@ namespace optiling return ge::GRAPH_SUCCESS; } - static ge::graphStatus CheckPositiveInt(int32_t value, const char *errorMessage) + static ge::graphStatus CheckPositiveInt(int64_t value, const char *errorMessage) { if (value < 0) { printf("%s can not be smaller than 0\n", errorMessage); @@ -67,7 +73,7 @@ namespace optiling return ge::GRAPH_FAILED; } - int32_t inputShape = inputTensor->GetShapeSize(); + int64_t inputShape = inputTensor->GetShapeSize(); if (CheckPositiveInt(inputShape, "inputShape") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; } @@ -77,8 +83,8 @@ namespace optiling return ge::GRAPH_FAILED; } - const int32_t inputShapeTmp = (inputShape > 0) ? inputShape : 1; - int32_t inputDim = inputTensor1->GetShapeSize() / inputShapeTmp; + const int64_t inputShapeTmp = (inputShape > 0) ? inputShape : 1; + int64_t inputDim = inputTensor1->GetShapeSize() / inputShapeTmp; if (CheckPositiveInt(inputDim, "inputDim") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; } @@ -97,15 +103,15 @@ namespace optiling ge::DataType inputDatatype = inputTensor1->GetDataType(); int32_t embeddingType; if (inputDatatype == ge::DT_FLOAT16) { - embeddingType = 2; + embeddingType = EMBEDDING_TYPE_FLOAT16; } else if (inputDatatype == ge::DT_INT32) { - embeddingType = 0; + embeddingType = EMBEDDING_TYPE_INT32; } else { - embeddingType = 1; + embeddingType = EMBEDDING_TYPE_FLOAT32; } int32_t typeSize = SIZE_OF_FLOAT_OR_INT; - if (embeddingType == 2) { + if (embeddingType == EMBEDDING_TYPE_FLOAT16) { typeSize = SIZE_OF_HALF; } int32_t alignNum = MIN_BLOCK_SIZE / typeSize; @@ -116,7 +122,9 @@ namespace optiling int32_t occupyAddressBytesNum = sizeof(int64_t) + typeSize * inputDimAligned * PING_PONG_NUM * 2; // 一轮计算中最多计算多少个addr,由于地址也要搬到ub,所以需要对齐32 - int32_t addrPerLoop = (UB_LIMIT / occupyAddressBytesNum) & (~3); // & (~3),保证地址数是4的倍数 + int64_t addrPerLoop = static_cast( + UB_LIMIT / static_cast(occupyAddressBytesNum) & (~3U)); // & (~3U),保证地址数是4的倍数 + if (CheckPositiveInt(addrPerLoop, "addrPerLoop") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; } @@ -125,8 +133,8 @@ namespace optiling tiling.set_update_type(updateType); tiling.set_embedding_type(embeddingType); - tiling.set_update_dim(inputDim); - tiling.set_addr_nums(inputShape); + tiling.set_update_dim(static_cast(inputDim)); + tiling.set_addr_nums(static_cast(inputShape)); tiling.set_addr_per_loop(addrPerLoop); tiling.set_type_size(typeSize); tiling.set_input_dim_aligned(inputDimAligned); diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp index 1a58768cbbc3d1b04e444d13661393df41639c11..f6a1e656a9f50d4a872503a0ce68585c018741d6 100644 --- a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp @@ -16,6 +16,8 @@ See the License for the specific language governing permissions and #include "kernel_operator.h" using namespace AscendC; +namespace AscendC { + constexpr int32_t SIZE_OF_HALF = 2; constexpr int32_t SIZE_OF_FLOAT_OR_INT = 4; constexpr int32_t PADDING_ZERO_NUM_PER_TIME = 8; @@ -32,7 +34,7 @@ public: needComputeAddrLen 
= singleCoreAddrLen; if (block_idx == block_num - 1) // 最后一个core,需要多计算的addr长度 { - needComputeAddrLen = addrNums * sizeof(int64_t) - singleCoreAddrLen * (block_num - 1); + needComputeAddrLen = addrNums * sizeof(int64_t) - singleCoreAddrLen * (block_num - 1); } loopCount = needComputeAddrLen / (addrNumPerLoop * sizeof(int64_t)); // 可能为0 @@ -42,6 +44,12 @@ public: pipe.InitBuffer(inQueue, pingpongNum, veclen); pipe.InitBuffer(outQueue, pingpongNum, veclen); +#ifdef L2_CACHE_HINT + // set `GlobalTensor` cache mode explicitly + srcAddrGlobal.SetL2CacheHint(CacheMode::CACHE_MODE_NORMAL); + dstDataGm.SetL2CacheHint(CacheMode::CACHE_MODE_NORMAL); +#endif + // get start index for current core, core parallel block_indx block_dim,即使是最后一个核也应该多初始化一些,并对齐4的倍数 srcAddrGlobal.SetGlobalBuffer((__gm__ int64_t *)(address + block_idx * singleCoreAddrLen), needComputeAddrLen); dstDataGm.SetGlobalBuffer((__gm__ T *)(y)); @@ -73,105 +81,102 @@ public: if (loopCount > 0) { - for (int32_t i = 0; i < loopCount; i++) - { - DataCopy(srcAddrLocal, srcAddrGlobal[i * addrNumPerLoop], addrNumPerLoop); - MoveProcess(srcAddrLocal, i, addrNumPerLoop); - } + for (int32_t i = 0; i < loopCount; i++) { + DataCopy(srcAddrLocal, srcAddrGlobal[i * addrNumPerLoop], addrNumPerLoop); + MoveProcess(srcAddrLocal, i, addrNumPerLoop); + } } // 处理最后一张卡剩下的addr int unProcess = (needComputeAddrLen / sizeof(int64_t)) % addrNumPerLoop; if (unProcess) { - int unProcessAligned = (unProcess + 3) & (~3); // 处理 addressList 不对齐32b的情况 - // 地址列表访问越界,对齐考虑无问题,会自动多申请一部分,兼容 - DataCopy(srcAddrLocal, srcAddrGlobal[loopCount * addrNumPerLoop], unProcessAligned); - MoveProcess(srcAddrLocal, loopCount, unProcess); + int unProcessAligned = static_cast + ((static_cast(unProcess) + 3) & (~3U)); // 处理 addressList 不对齐32b的情况 + // 地址列表访问越界,对齐考虑无问题,会自动多申请一部分,兼容 + DataCopy(srcAddrLocal, srcAddrGlobal[loopCount * addrNumPerLoop], unProcessAligned); + MoveProcess(srcAddrLocal, loopCount, unProcess); } } private: - __aicore__ inline void MoveProcess(const LocalTensor srcAddrLocal, const int turns, int addrNum) - { - set_flag(PIPE_MTE2, PIPE_S, 0); - wait_flag(PIPE_MTE2, PIPE_S, 0); - LocalTensor dataLocal = inQueue.AllocTensor(); // Queue的大小可以容下一个循环的所有emb - bool isFull = false; - int nums = 0; - int outIndex = 0; - int times = embDimAligned >> 3; // >>3位运算:除以8。 embDimAligned一定是8的倍数,若地址无效时,每次填充8个0 - int tmpCache = cache - 1; // 设计初是一次cache执行多次copyin、一次compute和一次copyout,现状是一次loop就只对应一次cache - - for (int i = 0; i < addrNum; i++) + __aicore__ inline void MoveProcess(const LocalTensor srcAddrLocal, const int turns, int addrNum) { - // 多次copyIn, 对应一次compute和copyOut,由cache决定 - dataLocal = isFull ? 
inQueue.AllocTensor() : dataLocal; - int64_t address = srcAddrLocal.GetValue(i); - - if (address != 0) - { - srcDataBufferGm.SetGlobalBuffer((__gm__ T *)(address), embDimAligned); - DataCopy(dataLocal[embDimAligned * nums], srcDataBufferGm, embDimAligned); - } - else - { - for (int j = 0; j < times; j++) + set_flag(PIPE_MTE2, PIPE_S, 0); + wait_flag(PIPE_MTE2, PIPE_S, 0); + LocalTensor dataLocal = inQueue.AllocTensor(); // Queue的大小可以容下一个循环的所有emb + bool isFull = false; + int nums = 0; + int outIndex = 0; + int times = embDimAligned >> 3; // >>3位运算:除以8。 embDimAligned一定是8的倍数,若地址无效时,每次填充8个0 + int tmpCache = cache - 1; // 设计初是一次cache执行多次copyin、一次compute和一次copyout,现状是一次loop就只对应一次cache + + for (int i = 0; i < addrNum; i++) { - Duplicate(dataLocal[embDimAligned * nums + j * PADDING_ZERO_NUM_PER_TIME], (T)0, PADDING_ZERO_NUM_PER_TIME); + // 多次copyIn, 对应一次compute和copyOut,由cache决定 + dataLocal = isFull ? inQueue.AllocTensor() : dataLocal; + int64_t address = srcAddrLocal.GetValue(i); + + if (address != 0) { +#ifdef L2_CACHE_HINT + srcDataBufferGm.SetL2CacheHint(CacheMode::CACHE_MODE_NORMAL); +#endif + srcDataBufferGm.SetGlobalBuffer((__gm__ T *)(address), embDimAligned); + DataCopy(dataLocal[embDimAligned * nums], srcDataBufferGm, embDimAligned); + } else { + for (int j = 0; j < times; j++) { + Duplicate(dataLocal[embDimAligned * nums + j * PADDING_ZERO_NUM_PER_TIME], + (T)0, PADDING_ZERO_NUM_PER_TIME); + } + } + + nums++; + isFull = (i == tmpCache || i == addrNum - 1); // cache满了,或者最后一个地址 + if (isFull) { + inQueue.EnQue(dataLocal); + Compute(nums); + CopyOut(outIndex, turns, nums); + nums = 0; + outIndex = i + 1; + tmpCache += cache; + } } - } - - nums++; - isFull = (i == tmpCache || i == addrNum - 1); // cache满了,或者最后一个地址 - if (isFull) - { - inQueue.EnQue(dataLocal); - Compute(nums); - CopyOut(outIndex, turns, nums); - nums = 0; - outIndex = i + 1; - tmpCache += cache; - } } - } - __aicore__ inline void Compute(const int nums) - { - // deque input tensors from VECIN queue - LocalTensor srcLocal = inQueue.DeQue(); - LocalTensor dstLocal = outQueue.AllocTensor(); + __aicore__ inline void Compute(const int nums) + { + // deque input tensors from VECIN queue + LocalTensor srcLocal = inQueue.DeQue(); + LocalTensor dstLocal = outQueue.AllocTensor(); - DataCopyParams copyParams; - copyParams.blockCount = 1; - copyParams.blockLen = (embDimAligned * sizeof(T) * nums) >> 5; // >> 5, 除以32,ub空间对齐 - DataCopy(dstLocal, srcLocal, copyParams); + DataCopyParams copyParams; + copyParams.blockCount = 1; + copyParams.blockLen = (embDimAligned * sizeof(T) * nums) >> 5; // >> 5, 除以32,ub空间对齐 + DataCopy(dstLocal, srcLocal, copyParams); - outQueue.EnQue(dstLocal); - inQueue.FreeTensor(srcLocal); - } + outQueue.EnQue(dstLocal); + inQueue.FreeTensor(srcLocal); + } - __aicore__ inline void CopyOut(const int index, const int turns, const int nums) - { - LocalTensor dstLocal = outQueue.DeQue(); + __aicore__ inline void CopyOut(const int index, const int turns, const int nums) + { + LocalTensor dstLocal = outQueue.DeQue(); - int offset = block_idx * dim * singleCoreAddrLen / sizeof(int64_t) + (turns * addrNumPerLoop * dim) + dim * index; + int offset = block_idx * dim * singleCoreAddrLen / + sizeof(int64_t) + (turns * addrNumPerLoop * dim) + dim * index; #if defined(__DAV_C220_VEC__) - if (typeSize == SIZE_OF_FLOAT_OR_INT) - { - copy_ubuf_to_gm_align_b32((__gm__ T *)dstDataGm[offset].GetPhyAddr(), (__ubuf__ T *)dstLocal.GetPhyAddr(), 0, - nums, dim * sizeof(T), 0, 0, 0, 0); - } - else if (typeSize == SIZE_OF_HALF) - { - 
copy_ubuf_to_gm_align_b16((__gm__ T *)dstDataGm[offset].GetPhyAddr(), (__ubuf__ T *)dstLocal.GetPhyAddr(), 0, - nums, dim * sizeof(T), 0, 0, 0, 0); - } + if (typeSize == SIZE_OF_FLOAT_OR_INT) { + copy_ubuf_to_gm_align_b32((__gm__ T *)dstDataGm[offset].GetPhyAddr(), + (__ubuf__ T *)dstLocal.GetPhyAddr(), 0, nums, dim * sizeof(T), 0, 0, 0, 0); + } else if (typeSize == SIZE_OF_HALF) { + copy_ubuf_to_gm_align_b16((__gm__ T *)dstDataGm[offset].GetPhyAddr(), + (__ubuf__ T *)dstLocal.GetPhyAddr(), 0, nums, dim * sizeof(T), 0, 0, 0, 0); + } #else - DataCopy(dstDataGm[offset], dstLocal, embDimAligned * nums); + DataCopy(dstDataGm[offset], dstLocal, embDimAligned * nums); #endif - outQueue.FreeTensor(dstLocal); - } + outQueue.FreeTensor(dstLocal); + } public: int32_t addrNumPerLoop, loopCount, singleCoreAddrLen, needComputeAddrLen, veclen, dim, pingpongNum, cache; @@ -186,6 +191,7 @@ private: GlobalTensor srcDataBufferGm, dstDataGm; GlobalTensor srcAddrGlobal; }; +} extern "C" __global__ __aicore__ void embedding_lookup_by_address(GM_ADDR address, GM_ADDR y, GM_ADDR usrWorkspace, GM_ADDR tiling) @@ -198,7 +204,7 @@ extern "C" __global__ __aicore__ void embedding_lookup_by_address(GM_ADDR addres { case 0: { - KernelEimtable op; + AscendC::KernelEimtable op; op.Init_param(tiling); op.Init(address, y); op.Process(); @@ -206,7 +212,7 @@ extern "C" __global__ __aicore__ void embedding_lookup_by_address(GM_ADDR addres break; case 2: { - KernelEimtable op; + AscendC::KernelEimtable op; op.Init_param(tiling); op.Init(address, y); op.Process(); @@ -214,7 +220,7 @@ extern "C" __global__ __aicore__ void embedding_lookup_by_address(GM_ADDR addres break; default: { - KernelEimtable op; + AscendC::KernelEimtable op; op.Init_param(tiling); op.Init(address, y); op.Process(); diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp index 988472602f5f867ec3732ed061a29399b739d014..50abf83c6d72b8d56abdbd4d0aea064c31ac5b7d 100644 --- a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp @@ -16,6 +16,7 @@ See the License for the specific language governing permissions and #include "kernel_operator.h" using namespace AscendC; +namespace KernelOps { constexpr int32_t SIZE_OF_HALF = 2; constexpr int32_t SIZE_OF_FLOAT_OR_INT = 4; @@ -31,7 +32,7 @@ public: needComputeAddrLen = singleCoreAddrLen; if (block_idx == block_num - 1) { - needComputeAddrLen = addrNums * sizeof(int64_t) - singleCoreAddrLen * (block_num - 1); + needComputeAddrLen = addrNums * sizeof(int64_t) - singleCoreAddrLen * (block_num - 1); } loopCount = needComputeAddrLen / (addrNumPerLoop * sizeof(int64_t)); @@ -39,9 +40,17 @@ public: pipe.InitBuffer(inQueue, pingpongNum, veclen); pipe.InitBuffer(outQueue, pingpongNum, veclen); +#ifdef L2_CACHE_HINT + // set `GlobalTensor` cache mode explicitly + srcAddrGlobal.SetL2CacheHint(CacheMode::CACHE_MODE_NORMAL); + srcDataBufferGm.SetL2CacheHint(CacheMode::CACHE_MODE_NORMAL); + outDataGm.SetL2CacheHint(CacheMode::CACHE_MODE_NORMAL); +#endif + // get start index for current core, core parallel block_indx block_dim srcAddrGlobal.SetGlobalBuffer((__gm__ int64_t *)(address + block_idx * singleCoreAddrLen)); - srcDataBufferGm.SetGlobalBuffer((__gm__ T *)(embedding + block_idx * singleCoreAddrLen / sizeof(int64_t) * sizeof(T) * dim)); + srcDataBufferGm.SetGlobalBuffer((__gm__ T *)(embedding + block_idx * singleCoreAddrLen + / sizeof(int64_t) * 
sizeof(T) * dim)); outDataGm.SetGlobalBuffer((__gm__ T *)(y)); } @@ -72,120 +81,111 @@ public: if (loopCount > 0) { - for (int32_t i = 0; i < loopCount; i++) - { - DataCopy(srcAddrLocal, srcAddrGlobal[i * addrNumPerLoop], addrNumPerLoop); - MoveProcess(srcAddrLocal, i, addrNumPerLoop); - } + for (int32_t i = 0; i < loopCount; i++) { + DataCopy(srcAddrLocal, srcAddrGlobal[i * addrNumPerLoop], addrNumPerLoop); + MoveProcess(srcAddrLocal, i, addrNumPerLoop); + } } int unProcess = (needComputeAddrLen / sizeof(int64_t)) % addrNumPerLoop; if (unProcess) { - int unProcessAligned = (unProcess + 3) & (~3); // 处理 addressList 不对齐32b的情况 - DataCopy(srcAddrLocal, srcAddrGlobal[loopCount * addrNumPerLoop], unProcessAligned); - MoveProcess(srcAddrLocal, loopCount, unProcess); + int unProcessAligned = (static_cast(unProcess) + 3) & (~3U); // 处理 addressList 不对齐32b的情况 + DataCopy(srcAddrLocal, srcAddrGlobal[loopCount * addrNumPerLoop], unProcessAligned); + MoveProcess(srcAddrLocal, loopCount, unProcess); } } private: - __aicore__ inline void MoveProcess(const LocalTensor srcAddrLocal, const int turns, int addrNum) - { - set_flag(PIPE_MTE2, PIPE_S, 0); - wait_flag(PIPE_MTE2, PIPE_S, 0); - LocalTensor dataLocal; - - int64_t address = 0; - if (dim == inputDimAligned) // copyIn 和 compute一次,copyOut多次 + __aicore__ inline void MoveProcess(const LocalTensor srcAddrLocal, const int turns, int addrNum) { - dataLocal = inQueue.AllocTensor(); - DataCopy(dataLocal, srcDataBufferGm[turns * addrNumPerLoop * dim], addrNum * inputDimAligned); - inQueue.EnQue(dataLocal); - - Compute(addrNum); // 只有copyOut的管道支持拷贝到gm上 - - LocalTensor dstLocal = outQueue.DeQue(); - if (updateType == 0) - { - SetAtomicAdd(); - } - for (int i = 0; i < addrNum; i++) - { - address = srcAddrLocal.GetValue(i); - if (address != 0) + set_flag(PIPE_MTE2, PIPE_S, 0); + wait_flag(PIPE_MTE2, PIPE_S, 0); + LocalTensor dataLocal; + + int64_t address = 0; + if (dim == inputDimAligned) // copyIn 和 compute一次,copyOut多次 { - dstDataGm.SetGlobalBuffer((__gm__ T*)(address)); - DataCopy(dstDataGm, dstLocal[i * inputDimAligned], inputDimAligned); + dataLocal = inQueue.AllocTensor(); + DataCopy(dataLocal, srcDataBufferGm[turns * addrNumPerLoop * dim], addrNum * inputDimAligned); + inQueue.EnQue(dataLocal); + + Compute(addrNum); // 只有copyOut的管道支持拷贝到gm上 + + LocalTensor dstLocal = outQueue.DeQue(); + if (updateType == 0) { + SetAtomicAdd(); + } + for (int i = 0; i < addrNum; i++) { + address = srcAddrLocal.GetValue(i); + if (address != 0) { +#ifdef L2_CACHE_HINT + dstDataGm.SetL2CacheHint(CacheMode::CACHE_MODE_NORMAL); +#endif + dstDataGm.SetGlobalBuffer((__gm__ T*)(address)); + DataCopy(dstDataGm, dstLocal[i * inputDimAligned], inputDimAligned); + } + } + if (updateType == 0) { + SetAtomicNone(); + } + outQueue.FreeTensor(dstLocal); + } else { + for (int i = 0; i < addrNum; i++) { + dataLocal = inQueue.AllocTensor(); + DataCopy(dataLocal, srcDataBufferGm[i * dim + turns * addrNumPerLoop * dim], inputDimAligned); + inQueue.EnQue(dataLocal); + Compute(1); + address = srcAddrLocal.GetValue(i); + CopyOut(address, turns, i); + } } - } - if (updateType == 0) - { - SetAtomicNone(); - } - outQueue.FreeTensor(dstLocal); } - else + + __aicore__ inline void Compute(const int nums) { - for (int i = 0; i < addrNum; i++) - { - dataLocal = inQueue.AllocTensor(); - DataCopy(dataLocal, srcDataBufferGm[i * dim + turns * addrNumPerLoop * dim], inputDimAligned); - inQueue.EnQue(dataLocal); - Compute(1); - address = srcAddrLocal.GetValue(i); - CopyOut(address, turns, i); - } + // deque input 
tensors from VECIN queue + LocalTensor srcLocal = inQueue.DeQue(); + LocalTensor dstLocal = outQueue.AllocTensor(); + DataCopyParams copyparams; + copyparams.blockCount = 1; + copyparams.blockLen = (inputDimAligned * sizeof(T) * nums) >> 5; // >> 5, 除以32,ub空间对齐 + DataCopy(dstLocal, srcLocal, copyparams); + outQueue.EnQue(dstLocal); + inQueue.FreeTensor(srcLocal); } - } - - __aicore__ inline void Compute(const int nums) - { - // deque input tensors from VECIN queue - LocalTensor srcLocal = inQueue.DeQue(); - LocalTensor dstLocal = outQueue.AllocTensor(); - DataCopyParams copyparams; - copyparams.blockCount = 1; - copyparams.blockLen = (inputDimAligned * sizeof(T) * nums) >> 5; // >> 5, 除以32,ub空间对齐 - DataCopy(dstLocal, srcLocal, copyparams); - outQueue.EnQue(dstLocal); - inQueue.FreeTensor(srcLocal); - } - - __aicore__ inline void CopyOut(const int64_t address, const int64_t turns, const int64_t index) - { - LocalTensor dstLocal = outQueue.DeQue(); - if (address != 0) + __aicore__ inline void CopyOut(const int64_t address, const int64_t turns, const int64_t index) { - dstDataGm.SetGlobalBuffer((__gm__ T *)(address)); + LocalTensor dstLocal = outQueue.DeQue(); - if (updateType == 0) - { - SetAtomicAdd(); - } + if (address != 0) { +#ifdef L2_CACHE_HINT + dstDataGm.SetL2CacheHint(CacheMode::CACHE_MODE_NORMAL); +#endif + dstDataGm.SetGlobalBuffer((__gm__ T *)(address)); + + if (updateType == 0) { + SetAtomicAdd(); + } #if defined(__DAV_C220_VEC__) - if (typeSize == SIZE_OF_FLOAT_OR_INT) - { - - copy_ubuf_to_gm_align_b32((__gm__ T *)dstDataGm.GetPhyAddr(), (__ubuf__ T *)dstLocal.GetPhyAddr(), 0, - 1, dim * sizeof(T), 0, 0, 0, 0); - } - else if (typeSize == SIZE_OF_HALF) - { - copy_ubuf_to_gm_align_b16((__gm__ T *)dstDataGm.GetPhyAddr(), (__ubuf__ T *)dstLocal.GetPhyAddr(), 0, - 1, dim * sizeof(T), 0, 0, 0, 0); - } + if (typeSize == SIZE_OF_FLOAT_OR_INT) { + copy_ubuf_to_gm_align_b32((__gm__ T *)dstDataGm.GetPhyAddr(), (__ubuf__ T *)dstLocal.GetPhyAddr(), 0, + 1, dim * sizeof(T), 0, 0, 0, 0); + } else if (typeSize == SIZE_OF_HALF) { + copy_ubuf_to_gm_align_b16((__gm__ T *)dstDataGm.GetPhyAddr(), (__ubuf__ T *)dstLocal.GetPhyAddr(), 0, + 1, dim * sizeof(T), 0, 0, 0, 0); + } #else - DataCopy(dstDataGm, dstLocal, inputDimAligned); + DataCopy(dstDataGm, dstLocal, inputDimAligned); #endif + } + if (updateType == 0) { + SetAtomicNone(); + } + outQueue.FreeTensor(dstLocal); } - if (updateType == 0) - { - SetAtomicNone(); - } - outQueue.FreeTensor(dstLocal); - } public: int32_t addrNumPerLoop, loopCount, singleCoreAddrLen, needComputeAddrLen, addrNums, cache, veclen, dim, pingpongNum; @@ -199,6 +199,7 @@ private: GlobalTensor srcDataBufferGm, dstDataGm, outDataGm; GlobalTensor srcAddrGlobal; }; +} extern "C" __global__ __aicore__ void embedding_update_by_address(GM_ADDR address, GM_ADDR embedding, GM_ADDR y, GM_ADDR usrWorkspace, GM_ADDR tiling) @@ -211,7 +212,7 @@ extern "C" __global__ __aicore__ void embedding_update_by_address(GM_ADDR addres { case 0: { - KernelEimtable_update op; + KernelOps::KernelEimtable_update op; op.Init_param(tiling); op.Init(address, embedding, y); op.Process(); @@ -219,7 +220,7 @@ extern "C" __global__ __aicore__ void embedding_update_by_address(GM_ADDR addres break; case 2: { - KernelEimtable_update op; + KernelOps::KernelEimtable_update op; op.Init_param(tiling); op.Init(address, embedding, y); op.Process(); @@ -227,7 +228,7 @@ extern "C" __global__ __aicore__ void embedding_update_by_address(GM_ADDR addres break; default: { - KernelEimtable_update op; + 
KernelOps::KernelEimtable_update op;
            op.Init_param(tiling);
            op.Init(address, embedding, y);
            op.Process();
diff --git a/cust_op/fused_lazy_adam/README.md b/cust_op/fused_lazy_adam/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3cb69f2d72c144be2b0b1b4b4817f60fc937c4b5
--- /dev/null
+++ b/cust_op/fused_lazy_adam/README.md
@@ -0,0 +1,184 @@
+# LazyAdam Fused Optimizer Operator and Sample Guide
+
+## File Structure of the LazyAdam Fused Operator
+
+```shell
+├── aclnn_lazy_adam_test # single-operator test case
+├── lazy_adam.json # operator prototype configuration
+├── op_host # host-side implementation of the fused LazyAdam operator
+├── op_kernel # kernel-side implementation of the fused LazyAdam operator
+├── README.md # this guide
+└── run.sh # install script for the fused LazyAdam operator
+```
+
+## Ascend C Reference
+
+For more details, see the official CANN manual on Ascend C operator development: [Ascend C operator development](https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/operatordev/Ascendcopdevg/atlas_ascendc_10_0001.html).
+
+## Using the LazyAdam Fused Operator
+
+1. Upload the fused_lazy_adam folder to the target environment, enter that directory, and run the following command to build and deploy the lazy_adam fused operator:
+
+```shell
+bash run.sh
+```
+
+Note: the CANN environment variables must be set before building and installing the operator. With CANN installed to its default path, set them as follows:
+
+```shell
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+```
+
+2. In the model script, create the lazy_adam optimizer and select the fused implementation. Example:
+
+```python
+from mx_rec.optimizers.lazy_adam import create_hash_optimizer
+
+# Pass "use_fusion_optim=True" when creating the lazy_adam optimizer to use the fused implementation.
+# use_fusion_optim defaults to False. See the mxRec user guide for detailed lazy_adam usage.
+sparse_optimizer = create_hash_optimizer(learning_rate=0.001, use_fusion_optim=True)
+```
+
+## Inside the LazyAdam Fused Operator
+
+1. Operator analysis
+
+a) The operator computes and updates the three tensors m, v, and variable during the backward update of the lazy_adam optimizer.
+b) Operator parameters:
+
+* gradient: gradient used by the lazy_adam computation;
+* indices: indices of the rows to compute/update;
+* inputM: first-moment estimate of the optimizer; updated in place;
+* inputV: second-moment estimate of the optimizer; updated in place;
+* inputVar: variable data of the embedding table; updated in place;
+* lr: learning rate;
+* beta1: exponential decay rate of the first-moment estimate;
+* beta2: exponential decay rate of the second-moment estimate;
+* epsilon: small constant for numerical stability;
+
+c) Operator constraints:
+
+* supported hardware: Atlas A2 series;
+* supported CANN versions: 8.0.RC1 and later;
+* supported input data type: float32;
+* the dim of the embedding table must be a multiple of 8;
+
+2. Host-side implementation
+
+The host-side implementation lives in fused_lazy_adam/op_host and consists of lazy_adam.cpp and lazy_adam_tiling.h.
+
+a) Tiling
+
+LazyAdamTilingFunc in namespace optiling reads the external arguments (input pointers and shape information) from the context and validates them, computes the data-split parameters the kernel needs (row, loopCount, batch, and so on; see the comments in the tiling file), sets BlockDim, and finally passes the attributes through TilingData.
+
+b) Shape inference
+
+Because the results are written back into the inputs in place, the InferShape and InferDataType functions in namespace ge have empty bodies.
+
+c) Prototype registration
+
+The LazyAdam class in namespace ops defines the operator prototype and registers the operator with GE.
+
+3. Kernel-side implementation
+
+The kernel-side implementation lives in fused_lazy_adam/op_kernel as lazy_adam.cpp.
+
+a) Kernel entry point: extern "C" __global__ __aicore__ void lazy_adam
+
+b) Tiling parameters: GET_TILING_DATA(tilingData, tiling) reads the data passed in from the host side out of TilingData
+
+c) The Init method initializes the operator's runtime data;
+
+d) The Process method moves the data in and computes, then writes the finished results back into the corresponding input arguments;
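+The m/v/var update that c) and d) implement matches the golden reference in
+aclnn_lazy_adam_test/scripts/gen_data.py. As a plain C++ sketch of the per-row
+math (the function name and flat layout are illustrative; the real kernel works
+on 32-byte-aligned UB tiles with Ascend C vector APIs):
+
+```c++
+#include <cmath>
+#include <cstdint>
+
+// One sparse LazyAdam step for a single embedding row of `dim` elements.
+// m and v are replaced in place; var accumulates the update (scatter-add).
+void LazyAdamRow(float *m, float *v, float *var, const float *grad, int64_t dim,
+                 float lr, float beta1, float beta2, float epsilon)
+{
+    for (int64_t j = 0; j < dim; ++j) {
+        m[j] = beta1 * m[j] + (1.0f - beta1) * grad[j];            // first-moment estimate
+        v[j] = beta2 * v[j] + (1.0f - beta2) * grad[j] * grad[j];  // second-moment estimate
+        var[j] += -lr * m[j] / (std::sqrt(std::fabs(v[j])) + epsilon);
+    }
+}
+```
+
+Rows whose index is negative are skipped entirely, which is what makes the
+update "lazy": only the rows addressed by indices are touched.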
+## AclNN Single-Operator Test Reference Design
+
+For more details, see the official CANN guide [Overview of Ascend C single-operator invocation](https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/operatordev/Ascendcopdevg/atlas_ascendc_10_0036.html).
+
+A single operator can be invoked in two ways: single-operator API execution and model execution. mxRec provides single-operator API execution as a reference.
+
+The single-operator test case lives in fused_lazy_adam/aclnn_lazy_adam_test, where:
+
+* inc holds the header files
+* scripts holds the Python scripts that generate and verify data
+* input holds the bin files with the operator inputs
+* output holds the generated executable execute_op, the operator output bin files, and the golden bin files used for verification
+* src holds the common utilities (common), the operator input/output descriptor class (operator_desc), the main single-operator invocation flow (op_runner), and the entry file (main)
+
+Run the single-operator test:
+
+```shell
+bash run.sh
+```
+
+### Prerequisites
+
+1. Create the operator project as described in [Creating an operator project with msopgen](https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/operatordev/Ascendcopdevg/atlas_ascendc_10_0023.html), prepare the kernel-side implementation as described in [Kernel-side operator implementation](https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/operatordev/Ascendcopdevg/atlas_ascendc_10_0024.html), and prepare the host-side implementation as described in [Host-side operator implementation](https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/operatordev/Ascendcopdevg/atlas_ascendc_10_0026.html).
+
+2. Build and deploy the operator as described in [Operator build and deployment](https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/operatordev/Ascendcopdevg/atlas_ascendc_10_0031.html). Enable binary compilation when building: in the project's build configuration file CMakePresets.json, set ENABLE_BINARY_PACKAGE to True. Deploying the operator binary to the current environment makes later invocation easier.
+
+3. Check that the header and library files needed for API execution were generated automatically: the cust_op/fused_lazy_adam/lazy_adam/build_out/autogen directory should contain aclnn_lazy_adam.cpp, aclnn_lazy_adam.h, and so on.
+
+Note: the cust_op/fused_lazy_adam/run.sh script deletes the build directory after installing the operator. When running the single-operator test, comment out the `rm -rf ./lazy_adam` step so that prerequisite 3 still holds.
+
+### AclNN Invocation of the LazyAdam Fused Operator
+
+The entry point is src/main.cpp:
+
+1. InitResource: initializes AscendCL and claims runtime resources; no changes needed
+2. RunOp runs the operator:
+
+a) CreateOpDesc builds the operator input/output description; the OperatorDesc definition (inc/operator_desc.h) keeps the operator arguments as member variables so op_runner can use them later;
+
+b) An OpRunner object is created and executes, in order:
+
+* opRunner.Init(): allocates the memory that holds the operator's input and output data
+* SetInputData(): loads the input bin files and hands the data to OpRunner's buffers for the subsequent execution
+* opRunner.RunOp(): executes the operator; the main flow is copy inputs, create a stream, run on the stream, copy outputs, release the stream
+* ProcessOutputData(): post-processes the operator outputs and writes them to disk for later comparison against the golden data
+
+3. DestroyResource: frees memory; no changes needed
+
+### The run Script
+
+run.sh performs, in order:
+
+1. Remove leftover generated files and log files
+2. Generate the input data and the golden data
+3. Build the acl executable
+4. Run the executable
+5. Compare against the golden files
+
+### The scripts Directory
+
+* gen_data.py: generates the input data for the LazyAdam fused operator and the golden data used for accuracy verification; the dim parameters used by the test can be adjusted.
+* verify_result.py: compares the operator output against the golden data generated by the script and reports the result. The comparison uses the tolerance loss = 1e-6 and looks at:
+
+a) the absolute error
+b) the relative error
+c) the number of out-of-tolerance elements
+
+The operator fails the accuracy check only when all of the following hold: the absolute errors are not all below loss, the relative errors are not all below loss, and the counts of elements whose absolute and relative errors exceed loss both exceed loss (one millionth, 1/1000000) of the total element count. In every other case the operator passes.
+
+Users may adjust the tolerance loss as needed.
\ No newline at end of file
diff --git a/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/inc/common.h b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/inc/common.h
new file mode 100644
index 0000000000000000000000000000000000000000..601a261764b0975bec2ff3a1e734004ad94fa872
--- /dev/null
+++ b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/inc/common.h
@@ -0,0 +1,52 @@
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
+
+#ifndef COMMON_H
+#define COMMON_H
+
+#include <cstdint>
+#include <cstdio>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "acl/acl.h"
+
+namespace AclnnLazyAdam {
+#define SUCCESS 0
+#define FAILED 1
+
+#define INFO_LOG(fmt, args...) fprintf(stdout, "[INFO] " fmt "\n", ##args)
+#define WARN_LOG(fmt, args...) fprintf(stdout, "[WARN] " fmt "\n", ##args)
+#define ERROR_LOG(fmt, args...) 
fprintf(stderr, "[ERROR] " fmt "\n", ##args) + + /** + * @brief Read data from file + * @param [in] filePath: file path + * @param [out] fileSize: file size + * @return read result + */ + bool ReadFile(const std::string &filePath, size_t fileSize, void *buffer, size_t bufferSize); + + /** + * @brief Write data to file + * @param [in] filePath: file path + * @param [in] buffer: data to write to file + * @param [in] size: size to write + * @return write result + */ + bool WriteFile(const std::string &filePath, const void *buffer, size_t size); +} +#endif // COMMON_H diff --git a/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/inc/op_runner.h b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/inc/op_runner.h new file mode 100644 index 0000000000000000000000000000000000000000..2e25341fa3d4257007c780f8e5abdc74260c82e9 --- /dev/null +++ b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/inc/op_runner.h @@ -0,0 +1,186 @@ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + +#ifndef OP_RUNNER_H +#define OP_RUNNER_H + +#include "aclnn/acl_meta.h" +#include "acl/acl.h" +#include "common.h" +#include "operator_desc.h" + +namespace AclnnLazyAdam { + /** + * Op Runner + */ + class OpRunner { + public: + /** + * @brief Constructor + * @param [in] opDesc: op description + */ + explicit OpRunner(OperatorDesc* opDesc); + + /** + * @brief Destructor + */ + virtual ~OpRunner(); + + /** + * @brief Init op runner + */ + bool Init(); + + /** + * @brief Init op runner output info + */ + bool InitOutputInfo(); + + /** + * @brief Get number of inputs + * @return number of inputs + */ + const size_t NumInputs(); + + /** + * @brief Get number of outputs + * @return number of outputs + */ + const size_t NumOutputs(); + + /** + * @brief Get input size by index + * @param [in] index: input index + * @return size of the input + */ + const size_t GetInputSize(size_t index) const; + + const size_t GetInputNumDims(size_t index) const; + + aclDataType GetInputDataType(size_t index) const; + + aclFormat GetInputFormat(size_t index) const; + + /** + * @brief Get output size by index + * @param [in] index: output index + * @return size of the output + */ + size_t GetOutputSize(size_t index) const; + + const size_t GetOutputNumDims(size_t index) const; + + aclDataType GetOutputDataType(size_t index) const; + + aclFormat GetOutputFormat(size_t index) const; + + /** + * @brief Get input element count by index + * @param i[in] ndex: input index + * @return element count of the input + */ + size_t GetInputElementCount(size_t index) const; + + /** + * @brief Get output element count by index + * @param [in] index: output index + * @return element count of the output + */ + size_t GetOutputElementCount(size_t index) const; + + /** + * @brief Get input shape by index + * @param [in] index: input index + * @return shape of the output + */ + std::vector GetInputShape(size_t index) const; + + /** + * @brief Get output shape 
by index + * @param [in] index: output index + * @return shape of the output + */ + std::vector GetOutputShape(size_t index) const; + + /** + * @brief Get input buffer(host memory) by index + * @tparam T: data type + * @param [in] index: input index + * @return host address of the input + */ + template + T* GetInputBuffer(size_t index) + { + if (index >= numInputs_) { + ERROR_LOG("index out of range. index = %zu, numInputs = %zu", index, numInputs_); + return nullptr; + } + return reinterpret_cast(hostInputs_[index]); + } + + /** + * @brief Get output buffer(host memory) by index + * @tparam T: data type + * @param [in] index: output index + * @return host address of the output + */ + template + const T* GetOutputBuffer(size_t index) + { + if (index >= numOutputs_) { + ERROR_LOG("index out of range. index = %zu, numOutputs = %zu", index, numOutputs_); + return nullptr; + } + + return reinterpret_cast(hostOutputs_[index]); + } + + /** + * @brief Compile static op + * @return compile result + */ + bool CompileStaticOp(); + + /** + * @brief Compile dynamic op + * @return compile result + */ + bool CompileDynamicOp(); + + /** + * @brief Run op + * @return run result + */ + bool RunOp(); + + private: + size_t numInputs_; + size_t numOutputs_; + + std::vector inputBuffers_; + std::vector outputBuffers_; + + std::vector devInputs_; + std::vector devOutputs_; + + std::vector hostInputs_; + std::vector hostOutputs_; + + std::vector inputTensor_; + std::vector outputTensor_; + OperatorDesc* opDesc_; + }; +} +#endif // OP_RUNNER_H diff --git a/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/inc/operator_desc.h b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/inc/operator_desc.h new file mode 100644 index 0000000000000000000000000000000000000000..ddd3b3a94e545708d2128f67d31d16ee709b2de3 --- /dev/null +++ b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/inc/operator_desc.h @@ -0,0 +1,67 @@ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. 
+==============================================================================*/ + +#ifndef OPERATOR_DESC_H +#define OPERATOR_DESC_H + +#include +#include + +#include "acl/acl.h" + +namespace AclnnLazyAdam { + /** + * Op description + */ + struct OperatorDesc { + /** + * Constructor + */ + explicit OperatorDesc(); + + /** + * Destructor + */ + virtual ~OperatorDesc(); + + /** + * Add an input tensor description + * @param [in] dataType: data type + * @param [in] numDims: number of dims + * @param [in] dims: dims + * @param [in] format: format + * @return OperatorDesc + */ + OperatorDesc &AddInputTensorDesc(aclDataType dataType, int numDims, const int64_t *dims, aclFormat format); + + /** + * Add an output tensor description + * @param [in] dataType: data type + * @param [in] numDims: number of dims + * @param [in] dims: dims + * @param [in] format: format + * @return OperatorDesc + */ + OperatorDesc &AddOutputTensorDesc(aclDataType dataType, int numDims, const int64_t *dims, aclFormat format); + + std::string opType; + std::vector inputDesc; + std::vector outputDesc; + double beta1; + double beta2; + double epsilon; + }; +} +#endif // OPERATOR_DESC_H diff --git a/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/input/.keep b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/input/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/output/.keep b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/output/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/run.sh b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..b44855dfd1b2133b39be0060aa6cc500d258161f --- /dev/null +++ b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/run.sh @@ -0,0 +1,106 @@ +#!/bin/bash +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +export ASCEND_SLOG_PRINT_TO_STDOUT=0 +export ASCEND_GLOBAL_LOG_LEVEL=0 + +CURRENT_DIR=$( + cd $(dirname ${BASH_SOURCE:-$0}) + pwd +) +cd $CURRENT_DIR + +# 导出环境变量 +SHORT=v:, +LONG=dtype:, +OPTS=$(getopt -a --options $SHORT --longoptions $LONG -- "$@") +eval set -- "$OPTS" +while : +do + case "$1" in + # float16, float, int32 + (-v | --dtype) + DTYPE="$2" + shift 2;; + (--) + shift; + break;; + (*) + echo "[ERROR] Unexpected option: $1"; + break;; + esac +done + +if [ ! 
$ASCEND_HOME_DIR ]; then + if [ -d "$HOME/Ascend/ascend-toolkit/latest" ]; then + export ASCEND_HOME_DIR=$HOME/Ascend/ascend-toolkit/latest + else + export ASCEND_HOME_DIR=/usr/local/Ascend/ascend-toolkit/latest + fi +fi +source $ASCEND_HOME_DIR/bin/setenv.bash + +export DDK_PATH=$ASCEND_HOME_DIR +arch=$(uname -m) +export NPU_HOST_LIB=$ASCEND_HOME_DIR/${arch}-linux/lib64 + +function main { + # 1. 清除遗留生成文件和日志文件 + rm -rf $HOME/ascend/log/* > /dev/null 2>&1 + rm ./input/*.bin > /dev/null 2>&1 + rm ./output/*.bin > /dev/null 2>&1 + + # 2. 生成输入数据和真值数据 + cd $CURRENT_DIR + python3 scripts/gen_data.py + if [ $? -ne 0 ]; then + echo "ERROR: generate input data failed!" + return 1 + fi + echo "INFO: generate input data success!" + + # 3. 编译acl可执行文件 + cd $CURRENT_DIR; rm -rf build; mkdir -p build; cd build + cmake ../src + if [ $? -ne 0 ]; then + echo "ERROR: cmake failed!" + return 1 + fi + echo "INFO: cmake success!" + make + if [ $? -ne 0 ]; then + echo "ERROR: make failed!" + return 1 + fi + echo "INFO: make success!" + + # 4. 运行可执行文件 + cd $CURRENT_DIR/output + echo "INFO: execute op!" + ./execute_op + + if [ $? -ne 0 ]; then + echo "ERROR: acl executable run failed! please check your project!" + return 1 + fi + echo "INFO: acl executable run success!" + + # 5. 比较真值文件 + cd $CURRENT_DIR + python3 scripts/verify_result.py +} + +main diff --git a/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/scripts/gen_data.py b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/scripts/gen_data.py new file mode 100644 index 0000000000000000000000000000000000000000..6e8c9251fe608e586c1f93207d8679306b7e4bc4 --- /dev/null +++ b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/scripts/gen_data.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import copy +import os +import numpy as np + +# 获取项目路径 +_CURRENT_PATH = os.path.dirname(os.path.abspath(__file__)) +_PROJECT_PATH = os.path.dirname(_CURRENT_PATH) +_INPUT_PATH = os.path.join(_PROJECT_PATH, "./input") +_OUTPUT_PATH = os.path.join(_PROJECT_PATH, "./output") + +_DIM_0 = 2000000 +_DIM_1 = 564096 +_DIM_2 = 32 + + +def _gather(input_data, indices): + out = np.zeros((len(indices), input_data.shape[1])) + for i, index_ in enumerate(indices): + # 跳过index小于0的数据 + if index_[0] < 0: + continue + out[i] = input_data[index_[0]] + return out + + +def _scatter_nd_update(momentum, indices, update_value): + out = copy.deepcopy(momentum) + for i, index_ in enumerate(indices): + if index_[0] < 0: + continue + else: + out[index_[0]] = update_value[i] + return out + + +def _scatter_nd_add(momentum, indices, update_value): + out = copy.deepcopy(momentum) + for i, index_ in enumerate(indices): + if index_[0] < 0: + continue + else: + out[indices[i][0]] = out[index_[0]] + update_value[i] + return out + + +def _gen_input_data(): + range_start = 1 + range_end = 2 + + dtype_chose = np.float32 + shape0 = (_DIM_0, _DIM_2) + indices_shape = (_DIM_1, 1) + grad_shape = (_DIM_1, _DIM_2) + + input_var = np.random.uniform(range_start, range_end, size=shape0).astype(dtype_chose) # shape [2000000,32] + input_m = np.random.uniform(range_start, range_end, size=shape0).astype(dtype_chose) # shape [2000000,32] + input_v = np.random.uniform(range_start, range_end, size=shape0).astype(dtype_chose) # shape [2000000,32] + + # indices shape [564096,1] + indices = np.random.permutation(np.arange(_DIM_0)).astype(np.int32)[:indices_shape[0]].reshape(-1, 1) + # gradient shape [564096,32] + gradient = np.random.uniform(range_start, range_end, size=grad_shape).astype(dtype_chose) + + if not os.path.exists(_INPUT_PATH): + os.makedirs(_INPUT_PATH) + indices.tofile(os.path.join(_INPUT_PATH, "indices.bin")) + gradient.tofile(os.path.join(_INPUT_PATH, "gradient.bin")) + input_m.tofile(os.path.join(_INPUT_PATH, "inputM.bin")) + input_v.tofile(os.path.join(_INPUT_PATH, "inputV.bin")) + input_var.tofile(os.path.join(_INPUT_PATH, "inputVar.bin")) + + +def _gen_golden_data(): + beta1 = 0.9 + beta2 = 0.999 + lr = 0.001 + epsilon = 1e-7 + + lr = np.array(lr).astype(np.float32) + beta1 = np.array(beta1).astype(np.float32) + beta2 = np.array(beta2).astype(np.float32) + epsilon = np.array(epsilon).astype(np.float32) + + lr.tofile(os.path.join(_INPUT_PATH, "learningRate.bin")) + + indices = np.fromfile(os.path.join(_INPUT_PATH, "indices.bin"), dtype=np.int32).reshape( + (_DIM_1, 1)) # shape (564096,1) + gradient = np.fromfile(os.path.join(_INPUT_PATH, "gradient.bin"), dtype=np.float32).reshape( + (_DIM_1, _DIM_2)) # shape (564096,32) + input_m = np.fromfile(os.path.join(_INPUT_PATH, "inputM.bin"), dtype=np.float32).reshape( + (_DIM_0, _DIM_2)) # shape (2000000,32) + input_v = np.fromfile(os.path.join(_INPUT_PATH, "inputV.bin"), dtype=np.float32).reshape( + (_DIM_0, _DIM_2)) # shape (2000000,32) + input_var = np.fromfile(os.path.join(_INPUT_PATH, "inputVar.bin"), dtype=np.float32).reshape( + (_DIM_0, _DIM_2)) # shape (2000000,32) + + old_m_slice = _gather(input_m, indices) # shape(564096,32) + old_m_slice = np.array(old_m_slice).astype(np.float32) # + update_m = beta1 * old_m_slice + (1 - beta1) * gradient + out_m = _scatter_nd_update(input_m, indices, update_m) + + old_v_slice = _gather(input_v, indices) + old_v_slice = 
np.array(old_v_slice).astype(np.float32) + update_v = beta2 * old_v_slice + (1 - beta2) * np.square(gradient) + out_v = _scatter_nd_update(input_v, indices, update_v) + + denominator_slice = np.sqrt(np.abs(update_v)) + epsilon + update_var = np.divide(-lr * update_m, denominator_slice) + out_var = _scatter_nd_add(input_var, indices, update_var) + + return out_m, out_v, out_var + + +def _gen_input_and_golden_data(): + # 产生输入数据 + _gen_input_data() + + # 产生真值数据 + out_m, out_v, out_var = _gen_golden_data() + if not os.path.exists(_OUTPUT_PATH): + os.makedirs(_OUTPUT_PATH) + out_m.tofile(os.path.join(_OUTPUT_PATH, "goldenOutputM.bin")) + out_v.tofile(os.path.join(_OUTPUT_PATH, "goldenOutputV.bin")) + out_var.tofile(os.path.join(_OUTPUT_PATH, "goldenOutputVar.bin")) + + +if __name__ == "__main__": + _gen_input_and_golden_data() diff --git a/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/scripts/verify_result.py b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/scripts/verify_result.py new file mode 100644 index 0000000000000000000000000000000000000000..1cc516dbb8ee630ae251d7a71e03c9cfbf688cc5 --- /dev/null +++ b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/scripts/verify_result.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+
+import logging
+import sys
+
+import numpy as np
+
+_LOSS_THRESHOLD = 1e-6  # tolerated error; fp16 typically requires abs and rel error within 1e-4, this fp32 test uses 1e-6
+_MINIMUM = 10e-10  # guards against division by zero in the relative error
+
+logging.getLogger().setLevel(logging.INFO)
+
+
+def verify_result(real_result, golden):
+    real_result = np.fromfile(real_result, dtype=np.float32)  # actual op output read from the bin file
+    golden = np.fromfile(golden, dtype=np.float32)  # expected output read from the bin file
+    result = np.abs(real_result - golden)  # absolute deviation between actual and expected results
+    deno = np.maximum(np.abs(real_result), np.abs(golden))  # element-wise denominator for the relative error
+    result_atol = np.less_equal(result, _LOSS_THRESHOLD)  # pass mask for the absolute error
+    result_rtol = np.less_equal(result / np.add(deno, _MINIMUM), _LOSS_THRESHOLD)  # pass mask for the relative error
+    if not result_rtol.all() and not result_atol.all():
+        # fail only when both error kinds exceed the tolerance on more than
+        # size * _LOSS_THRESHOLD elements (one millionth of the tensor)
+        if np.sum(~result_rtol) > real_result.size * _LOSS_THRESHOLD \
+                and np.sum(~result_atol) > real_result.size * _LOSS_THRESHOLD:
+            logging.error("[ERROR] output verify result error.")
+            return False
+    logging.info("output verify pass.")
+    return True
+
+
+if __name__ == '__main__':
+    logging.info("start verify outputM.")
+    m_ok = verify_result("output/outputM.bin", "output/goldenOutputM.bin")
+    logging.info("start verify outputV.")
+    v_ok = verify_result("output/outputV.bin", "output/goldenOutputV.bin")
+    logging.info("start verify outputVar.")
+    var_ok = verify_result("output/outputVar.bin", "output/goldenOutputVar.bin")
+    # propagate the comparison result so the caller can detect a failed verification
+    sys.exit(0 if (m_ok and v_ok and var_ok) else 1)
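The executable assembled by the CMakeLists.txt below links cust_opapi and drives the generated two-phase aclnn API declared in aclnn_lazy_adam.h. A rough, illustrative sketch of that call pattern, mirroring the attention_fusion_grad sample above (the parameter order of the generated functions is an assumption here; the authoritative signature is in the autogen header):

```c++
#include "acl/acl.h"
#include "aclnn_lazy_adam.h"  // generated header; the signature used below is illustrative

// Sketch of the two-phase aclnn invocation: query the workspace, then launch.
bool RunLazyAdamOnce(aclTensor *grad, aclTensor *indices, aclTensor *m, aclTensor *v,
                     aclTensor *var, aclTensor *lr, double beta1, double beta2,
                     double epsilon, aclrtStream stream)
{
    uint64_t workspaceSize = 0;
    aclOpExecutor *executor = nullptr;
    // Phase 1: compute workspace requirements and obtain an executor handle.
    if (aclnnLazyAdamGetWorkspaceSize(grad, indices, m, v, var, lr, beta1, beta2,
                                      epsilon, &workspaceSize, &executor) != ACL_SUCCESS) {
        return false;
    }
    void *workspace = nullptr;
    if (workspaceSize > 0 &&
        aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST) != ACL_SUCCESS) {
        return false;
    }
    // Phase 2: launch on the stream and wait for completion.
    bool ok = aclnnLazyAdam(workspace, workspaceSize, executor, stream) == ACL_SUCCESS &&
              aclrtSynchronizeStreamWithTimeout(stream, 5000) == ACL_SUCCESS;
    if (workspace != nullptr) {
        (void)aclrtFree(workspace);
    }
    return ok;
}
```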
diff --git a/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/src/CMakeLists.txt b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/src/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c2366f4a78dd73ed67d667d965adaefd5a32eed4
--- /dev/null
+++ b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/src/CMakeLists.txt
@@ -0,0 +1,69 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+
+# CMake lowest version requirement
+cmake_minimum_required(VERSION 3.5.1)
+
+# project information
+project(acl_execute_lazy_adam)
+
+# Compile options
+add_compile_options(-std=c++11)
+
+set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "../output")
+set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "../output")
+
+set(INC_PATH $ENV{DDK_PATH})
+
+if (NOT DEFINED ENV{DDK_PATH})
+    set(INC_PATH "/usr/local/Ascend/ascend-toolkit/latest")
+    message(STATUS "set default INC_PATH: ${INC_PATH}")
+else ()
+    message(STATUS "env INC_PATH: ${INC_PATH}")
+endif ()
+
+set(CUST_PKG_PATH "${INC_PATH}/opp/vendors/mxrec_fused_lazy_adam/op_api")
+
+set(LIB_PATH $ENV{NPU_HOST_LIB})
+
+# Dynamic libraries in the stub directory can only be used for compilation
+if (NOT DEFINED ENV{NPU_HOST_LIB})
+    set(LIB_PATH "/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64/stub/")
+    set(LIB_PATH1 "/usr/local/Ascend/ascend-toolkit/latest/atc/lib64/stub/")
+    message(STATUS "set default LIB_PATH: ${LIB_PATH}")
+else ()
+    message(STATUS "env LIB_PATH: ${LIB_PATH}")
+endif ()
+
+set(AUTO_GEN_PATH "../../lazy_adam/build_out/autogen")
+# Header path
+include_directories(
+    ${INC_PATH}/runtime/include
+    ${INC_PATH}/atc/include
+    ../inc
+    ${CUST_PKG_PATH}/include
+    ${AUTO_GEN_PATH}
+)
+
+# add host lib path
+link_directories(
+    ${LIB_PATH}
+    ${LIB_PATH1}
+    ${CUST_PKG_PATH}/lib
+)
+
+add_executable(execute_op
+    main.cpp
+    operator_desc.cpp
+    op_runner.cpp
+    common.cpp
+)
+
+target_link_libraries(execute_op
+    ascendcl
+    cust_opapi
+    acl_op_compiler
+    nnopbase
+    stdc++
+)
+
+install(TARGETS execute_op DESTINATION ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
diff --git a/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/src/common.cpp b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/src/common.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e2cd6865ab4e4c8b0b58aec90987cb85b235ca70
--- /dev/null
+++ b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/src/common.cpp
@@ -0,0 +1,84 @@
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
+
+#include <sys/stat.h>
+#include <fcntl.h>
+#include <unistd.h>
+#include <fstream>
+
+#include "common.h"
+
+namespace AclnnLazyAdam {
+    bool ReadFile(const std::string &filePath, size_t fileSize, void *buffer, size_t bufferSize)
+    {
+        struct stat sBuf;
+        int fileStatus = stat(filePath.data(), &sBuf);
+        if (fileStatus == -1) {
+            ERROR_LOG("failed to get file %s", filePath.c_str());
+            return false;
+        }
+        if (S_ISREG(sBuf.st_mode) == 0) {
+            ERROR_LOG("%s is not a file, please enter a file", filePath.c_str());
+            return false;
+        }
+
+        std::ifstream file;
+        file.open(filePath, std::ios::binary);
+        if (!file.is_open()) {
+            ERROR_LOG("Open file failed. 
path = %s", filePath.c_str()); + return false; + } + + std::filebuf *buf = file.rdbuf(); + size_t size = buf->pubseekoff(0, std::ios::end, std::ios::in); + if (size == 0) { + ERROR_LOG("file size is 0"); + file.close(); + return false; + } + if (size > bufferSize) { + ERROR_LOG("file size is larger than buffer size"); + file.close(); + return false; + } + buf->pubseekpos(0, std::ios::in); + buf->sgetn(static_cast(buffer), size); + fileSize = size; + file.close(); + return true; + } + + bool WriteFile(const std::string &filePath, const void *buffer, size_t size) + { + if (buffer == nullptr) { + ERROR_LOG("Write file failed. buffer is nullptr"); + return false; + } + int fd = open(filePath.c_str(), O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWRITE); + if (fd < 0) { + ERROR_LOG("Open file failed. path = %s", filePath.c_str()); + return false; + } + + auto writeSize = write(fd, buffer, size); + (void) close(fd); + if (writeSize != size) { + ERROR_LOG("Write file Failed."); + return false; + } + + return true; + } +} \ No newline at end of file diff --git a/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/src/main.cpp b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/src/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c42539961be4365a99fbb055d8a6e78e6b6d1c13 --- /dev/null +++ b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/src/main.cpp @@ -0,0 +1,228 @@ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "acl/acl.h" +#include "aclnn_lazy_adam.h" +#include "common.h" +#include "op_runner.h" + +using namespace AclnnLazyAdam; + +bool g_isDevice = false; +int g_deviceId = 0; +namespace { + constexpr int DIM0 = 2000000; // inputM inputV inputVar 的行数 + constexpr int DIM1 = 564096; // indices长度 + constexpr int DIM2 = 32; // inputM inputV inputVar gradient等每行的数据个数 + constexpr int INPUT_M_INDEX = 2; + constexpr int INPUT_V_INDEX = 3; + constexpr int INPUT_VAR_INDEX = 4; + constexpr int LEARNING_RATE_INDEX = 5; + constexpr int OUTPUT_M_INDEX = 0; + constexpr int OUTPUT_V_INDEX = 1; + constexpr int OUTPUT_VAR_INDEX = 2; + constexpr float LEARNING_RATE = 0.001; + constexpr float BETA1 = 0.9; + constexpr float BETA2 = 0.999; + constexpr float EPSILON = 1e-7; + const char* READ_ERROR_INFO = "read input file error, please check whether file exist and access rights is correct"; + const char* WRITE_ERROR_INFO = "write output file error, please check access rights is correct"; + + OperatorDesc CreateOpDesc() + { + std::vector indicesShape{DIM1, 1}; + std::vector gradientShape{DIM1, DIM2}; + std::vector inputMShape{DIM0, DIM2}; // inputM inputV inputVar 的shape相同 + std::vector learningRateShape{1}; + aclDataType dataType = ACL_FLOAT; + aclDataType indexDataType = ACL_INT32; + aclFormat format = ACL_FORMAT_ND; + OperatorDesc opDesc; + opDesc.AddInputTensorDesc(dataType, gradientShape.size(), gradientShape.data(), format); + opDesc.AddInputTensorDesc(indexDataType, indicesShape.size(), indicesShape.data(), format); + opDesc.AddInputTensorDesc(dataType, inputMShape.size(), inputMShape.data(), format); // inputM + opDesc.AddInputTensorDesc(dataType, inputMShape.size(), inputMShape.data(), format); // inputV + opDesc.AddInputTensorDesc(dataType, inputMShape.size(), inputMShape.data(), format); // inputVar + opDesc.AddInputTensorDesc(dataType, learningRateShape.size(), learningRateShape.data(), + format); // learningRate + opDesc.beta1 = BETA1; + opDesc.beta2 = BETA2; + opDesc.epsilon = EPSILON; + return opDesc; + } + + bool SetInputData(OpRunner& runner) + { + size_t fileSize = 0; + if (!ReadFile("../input/gradient.bin", fileSize, runner.GetInputBuffer(0), runner.GetInputSize(0))) { + throw std::runtime_error(READ_ERROR_INFO); + } + if (!ReadFile("../input/indices.bin", fileSize, runner.GetInputBuffer(1), runner.GetInputSize(1))) { + throw std::runtime_error(READ_ERROR_INFO); + } + if (!ReadFile("../input/inputM.bin", fileSize, runner.GetInputBuffer(INPUT_M_INDEX), + runner.GetInputSize(INPUT_M_INDEX))) { + throw std::runtime_error(READ_ERROR_INFO); + } + if (!ReadFile("../input/inputV.bin", fileSize, runner.GetInputBuffer(INPUT_V_INDEX), + runner.GetInputSize(INPUT_V_INDEX))) { + throw std::runtime_error(READ_ERROR_INFO); + } + if (!ReadFile("../input/inputVar.bin", fileSize, runner.GetInputBuffer(INPUT_VAR_INDEX), + runner.GetInputSize(INPUT_VAR_INDEX))) { + throw std::runtime_error(READ_ERROR_INFO); + } + if (!ReadFile("../input/learningRate.bin", fileSize, runner.GetInputBuffer(LEARNING_RATE_INDEX), + runner.GetInputSize(LEARNING_RATE_INDEX))) { + throw std::runtime_error(READ_ERROR_INFO); + } + INFO_LOG("Set input success"); + return true; + } + + bool ProcessOutputData(OpRunner& runner) + { + // 保存输出数据 由于输出仅有hostOutputs_数据,未设置outputDesc,因此数据size从inputTensor获取 + if (!WriteFile("../output/outputM.bin", runner.GetOutputBuffer(OUTPUT_M_INDEX), + 
runner.GetInputSize(INPUT_M_INDEX))) { + throw std::runtime_error(WRITE_ERROR_INFO); + } + if (!WriteFile("../output/outputV.bin", runner.GetOutputBuffer(OUTPUT_V_INDEX), + runner.GetInputSize(INPUT_V_INDEX))) { + throw std::runtime_error(WRITE_ERROR_INFO); + } + if (!WriteFile("../output/outputVar.bin", runner.GetOutputBuffer(OUTPUT_VAR_INDEX), + runner.GetInputSize(INPUT_VAR_INDEX))) { + throw std::runtime_error(WRITE_ERROR_INFO); + } + INFO_LOG("Write output success"); + return true; + } + + void DestroyResource() + { + bool flag = false; + if (aclrtResetDevice(g_deviceId) != ACL_SUCCESS) { + ERROR_LOG("Reset device %d failed", g_deviceId); + flag = true; + } + INFO_LOG("Reset Device success"); + if (aclFinalize() != ACL_SUCCESS) { + ERROR_LOG("Finalize acl failed"); + flag = true; + } + if (flag) { + ERROR_LOG("Destroy resource failed"); + } else { + INFO_LOG("Destroy resource success"); + } + } + + bool InitResource() + { + std::string output = "../output"; + if (access(output.c_str(), 0) == -1) { + int ret = mkdir(output.c_str(), 0700); + if (ret == 0) { + INFO_LOG("Make output directory successfully"); + } else { + ERROR_LOG("Make output directory fail"); + return false; + } + } + + // acl.json is dump or profiling config file + if (aclInit(NULL) != ACL_SUCCESS) { + ERROR_LOG("acl init failed"); + return false; + } + + if (aclrtSetDevice(g_deviceId) != ACL_SUCCESS) { + ERROR_LOG("Set device failed. g_deviceId is %d", g_deviceId); + (void) aclFinalize(); + return false; + } + INFO_LOG("Set device[%d] success", g_deviceId); + + // runMode is ACL_HOST which represents app is running in host + // runMode is ACL_DEVICE which represents app is running in device + aclrtRunMode runMode; + if (aclrtGetRunMode(&runMode) != ACL_SUCCESS) { + ERROR_LOG("Get run mode failed"); + DestroyResource(); + return false; + } + g_isDevice = (runMode == ACL_DEVICE); + INFO_LOG("Get RunMode[%d] success", runMode); + + return true; + } + + bool RunOp() + { + // create op desc + OperatorDesc opDesc = CreateOpDesc(); + + // create Runner + OpRunner opRunner(&opDesc); + if (!opRunner.Init()) { + ERROR_LOG("Init OpRunner failed"); + return false; + } + + // Load inputs + if (!SetInputData(opRunner)) { + ERROR_LOG("Set input data failed"); + return false; + } + + // Run op + if (!opRunner.RunOp()) { + ERROR_LOG("Run op failed"); + return false; + } + + // process output data + if (!ProcessOutputData(opRunner)) { + ERROR_LOG("Process output data failed"); + return false; + } + INFO_LOG("Run op success"); + return true; + } +} + +int main(int argc, char** argv) +{ + if (!InitResource()) { + ERROR_LOG("Init resource failed"); + return FAILED; + } + INFO_LOG("Init resource success"); + + if (!RunOp()) { + DestroyResource(); + return FAILED; + } + DestroyResource(); + return SUCCESS; +} diff --git a/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/src/op_runner.cpp b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/src/op_runner.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3b9b51feb5ffcfc41535e5f4fbdcfdb4944d5d22 --- /dev/null +++ b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/src/op_runner.cpp @@ -0,0 +1,354 @@ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "op_runner.h"
+
+// NOTE: the bracketed header names were lost in extraction; <iostream> and <iomanip>
+// are inferred from the formatted-print constants below.
+#include <iostream>
+#include <iomanip>
+
+#include "acl/acl_op_compiler.h"
+#include "aclnn_lazy_adam.h"
+#include "common.h"
+
+extern bool g_isDevice;
+
+namespace AclnnLazyAdam {
+    using namespace std;
+    constexpr int PRINT_OUT_WIDTH = 10;
+    constexpr int PRINT_OUT_PRECISION = 4;
+    constexpr int STREAM_TIMEOUT = 5000; // timeout for waiting on stream tasks, in ms
+    constexpr int OUTPUT_SIZE = 3;
+    constexpr int INPUT_TENSOR_OFFSET = 2;
+
+    OpRunner::OpRunner(OperatorDesc* opDesc) : opDesc_(opDesc)
+    {
+        numInputs_ = opDesc->inputDesc.size();
+        numOutputs_ = opDesc->outputDesc.size();
+    }
+
+    OpRunner::~OpRunner()
+    {
+        for (size_t i = 0; i < numInputs_; ++i) {
+            (void) aclDestroyTensor(inputTensor_[i]);
+            (void) aclDestroyDataBuffer(inputBuffers_[i]);
+            (void) aclrtFree(devInputs_[i]);
+            if (g_isDevice) {
+                (void) aclrtFree(hostInputs_[i]);
+            } else {
+                (void) aclrtFreeHost(hostInputs_[i]);
+            }
+        }
+        for (size_t i = 0; i < numOutputs_; ++i) {
+            if (g_isDevice) {
+                (void) aclrtFree(hostOutputs_[i]);
+            } else {
+                (void) aclrtFreeHost(hostOutputs_[i]);
+            }
+        }
+    }
+
+    bool OpRunner::InitOutputInfo()
+    {
+        // Output handling is hand-modified: only host-side output buffers are allocated,
+        // and the destructor above must stay in sync with this.
+        numOutputs_ = OUTPUT_SIZE;
+        for (size_t i = 0; i < numOutputs_; ++i) {
+            int inputTensorIndex = i + INPUT_TENSOR_OFFSET;
+            auto size = GetInputSize(inputTensorIndex);
+
+            void* hostOutput = nullptr;
+            if (g_isDevice) {
+                if (aclrtMalloc(&hostOutput, size, ACL_MEM_MALLOC_NORMAL_ONLY) != ACL_SUCCESS) {
+                    ERROR_LOG("Malloc device memory for output[%zu] failed", i);
+                    return false;
+                }
+            } else {
+                if (aclrtMallocHost(&hostOutput, size) != ACL_SUCCESS) {
+                    ERROR_LOG("Malloc host memory for output[%zu] failed", i);
+                    return false;
+                }
+            }
+            if (hostOutput == nullptr) {
+                ERROR_LOG("Malloc memory for output[%zu] failed", i);
+                return false;
+            }
+            hostOutputs_.emplace_back(hostOutput);
+        }
+        return true;
+    }
+
+    bool OpRunner::Init()
+    {
+        for (size_t i = 0; i < numInputs_; ++i) {
+            auto size = GetInputSize(i);
+            void* devMem = nullptr;
+            if (aclrtMalloc(&devMem, size, ACL_MEM_MALLOC_NORMAL_ONLY) != ACL_SUCCESS) {
+                ERROR_LOG("Malloc device memory for input[%zu] failed", i);
+                return false;
+            }
+            devInputs_.emplace_back(devMem);
+            inputBuffers_.emplace_back(aclCreateDataBuffer(devMem, size));
+
+            void* hostInput = nullptr;
+            if (g_isDevice) {
+                if (aclrtMalloc(&hostInput, size, ACL_MEM_MALLOC_NORMAL_ONLY) != ACL_SUCCESS) {
+                    ERROR_LOG("Malloc device memory for input[%zu] failed", i);
+                    return false;
+                }
+            } else {
+                if (aclrtMallocHost(&hostInput, size) != ACL_SUCCESS) {
+                    ERROR_LOG("Malloc host memory for input[%zu] failed", i);
+                    return false;
+                }
+            }
+            if (hostInput == nullptr) {
+                ERROR_LOG("Malloc memory for input[%zu] failed", i);
+                return false;
+            }
+            hostInputs_.emplace_back(hostInput);
+
+            aclTensor* inputTensor =
+                aclCreateTensor(GetInputShape(i).data(), GetInputNumDims(i), GetInputDataType(i), nullptr, 0,
+                                GetInputFormat(i), GetInputShape(i).data(), GetInputNumDims(i), devInputs_[i]);
+            if (inputTensor == nullptr) {
+                ERROR_LOG("Create Tensor for input[%zu] failed", i);
+                return false;
+            }
+            inputTensor_.emplace_back(inputTensor);
+        }
+
+        return InitOutputInfo();
+    }
+
+    const size_t OpRunner::NumInputs()
+    {
+        return numInputs_;
+    }
+
+    const size_t OpRunner::NumOutputs()
+    {
+        return numOutputs_;
+    }
+
+    const size_t OpRunner::GetInputSize(size_t index) const
+    {
+        if (index >= numInputs_) {
+            ERROR_LOG("index out of range. index = %zu, numInputs = %zu", index, numInputs_);
+            return 0;
+        }
+        return aclGetTensorDescSize(opDesc_->inputDesc[index]);
+    }
+
+    const size_t OpRunner::GetInputNumDims(size_t index) const
+    {
+        if (index >= numInputs_) {
+            ERROR_LOG("index out of range. index = %zu, numInputs = %zu", index, numInputs_);
+            return 0;
+        }
+        return aclGetTensorDescNumDims(opDesc_->inputDesc[index]);
+    }
+
+    aclDataType OpRunner::GetInputDataType(size_t index) const
+    {
+        if (index >= numInputs_) {
+            ERROR_LOG("index out of range. index = %zu, numInputs = %zu", index, numInputs_);
+            return ACL_DT_UNDEFINED;
+        }
+        return aclGetTensorDescType(opDesc_->inputDesc[index]);
+    }
+
+    aclFormat OpRunner::GetInputFormat(size_t index) const
+    {
+        if (index >= numInputs_) {
+            ERROR_LOG("index out of range. index = %zu, numInputs = %zu", index, numInputs_);
+            return ACL_FORMAT_UNDEFINED;
+        }
+        return aclGetTensorDescFormat(opDesc_->inputDesc[index]);
+    }
+
+    std::vector<int64_t> OpRunner::GetInputShape(size_t index) const
+    {
+        std::vector<int64_t> ret;
+        if (index >= numInputs_) {
+            ERROR_LOG("index out of range. index = %zu, numInputs = %zu", index, numInputs_);
+            return ret;
+        }
+
+        auto desc = opDesc_->inputDesc[index];
+        for (size_t i = 0; i < aclGetTensorDescNumDims(desc); ++i) {
+            int64_t dimSize;
+            if (aclGetTensorDescDimV2(desc, i, &dimSize) != ACL_SUCCESS) {
+                ERROR_LOG("get dims from tensor desc failed. dims index = %zu", i);
+                ret.clear();
+                return ret;
+            }
+            ret.emplace_back(dimSize);
+        }
+        return ret;
+    }
+
+    size_t OpRunner::GetOutputSize(size_t index) const
+    {
+        if (index >= numOutputs_) {
+            ERROR_LOG("index out of range. index = %zu, numOutputs = %zu", index, numOutputs_);
+            return 0;
+        }
+        return aclGetTensorDescSize(opDesc_->outputDesc[index]);
+    }
+
+    const size_t OpRunner::GetOutputNumDims(size_t index) const
+    {
+        if (index >= numOutputs_) {
+            ERROR_LOG("index out of range. index = %zu, numOutputs = %zu", index, numOutputs_);
+            return 0;
+        }
+        return aclGetTensorDescNumDims(opDesc_->outputDesc[index]);
+    }
+
+    aclDataType OpRunner::GetOutputDataType(size_t index) const
+    {
+        if (index >= numOutputs_) {
+            ERROR_LOG("index out of range. index = %zu, numOutputs = %zu", index, numOutputs_);
+            return ACL_DT_UNDEFINED;
+        }
+        return aclGetTensorDescType(opDesc_->outputDesc[index]);
+    }
+
+    aclFormat OpRunner::GetOutputFormat(size_t index) const
+    {
+        if (index >= numOutputs_) {
+            ERROR_LOG("index out of range. index = %zu, numOutputs = %zu", index, numOutputs_);
+            return ACL_FORMAT_UNDEFINED;
+        }
+
+        return aclGetTensorDescFormat(opDesc_->outputDesc[index]);
+    }
+    std::vector<int64_t> OpRunner::GetOutputShape(size_t index) const
+    {
+        std::vector<int64_t> ret;
+        if (index >= numOutputs_) {
+            ERROR_LOG("index out of range. index = %zu, numOutputs = %zu", index, numOutputs_);
+            return ret;
+        }
+
+        auto desc = opDesc_->outputDesc[index];
+        for (size_t i = 0; i < aclGetTensorDescNumDims(desc); ++i) {
+            int64_t dimSize;
+            if (aclGetTensorDescDimV2(desc, i, &dimSize) != ACL_SUCCESS) {
+                ERROR_LOG("get dims from tensor desc failed. dims index = %zu", i);
+                ret.clear();
+                return ret;
+            }
+            ret.emplace_back(dimSize);
+        }
+        return ret;
+    }
+
+    size_t OpRunner::GetInputElementCount(size_t index) const
+    {
+        if (index >= opDesc_->inputDesc.size()) {
+            ERROR_LOG("index out of range. index = %zu, numInputs = %zu", index, numInputs_);
+            return 0;
+        }
+
+        return aclGetTensorDescElementCount(opDesc_->inputDesc[index]);
+    }
+
+    size_t OpRunner::GetOutputElementCount(size_t index) const
+    {
+        if (index >= opDesc_->outputDesc.size()) {
+            ERROR_LOG("index out of range. index = %zu, numOutputs = %zu", index, numOutputs_);
+            return 0;
+        }
+        return aclGetTensorDescElementCount(opDesc_->outputDesc[index]);
+    }
+
+    bool OpRunner::RunOp()
+    {
+        for (size_t i = 0; i < numInputs_; ++i) {
+            auto size = GetInputSize(i);
+            aclrtMemcpyKind kind = ACL_MEMCPY_HOST_TO_DEVICE;
+            if (g_isDevice) {
+                kind = ACL_MEMCPY_DEVICE_TO_DEVICE;
+            }
+            if (aclrtMemcpy(devInputs_[i], size, hostInputs_[i], size, kind) != ACL_SUCCESS) {
+                ERROR_LOG("Copy input[%zu] failed", i);
+                return false;
+            }
+            INFO_LOG("Copy input[%zu] success", i);
+        }
+
+        aclrtStream stream = nullptr;
+        if (aclrtCreateStream(&stream) != ACL_SUCCESS) {
+            ERROR_LOG("Create stream failed");
+            return false;
+        }
+        INFO_LOG("Create stream success");
+
+        size_t workspaceSize = 0;
+        aclOpExecutor* handle = nullptr;
+        auto ret = aclnnLazyAdamGetWorkspaceSize(inputTensor_[0], inputTensor_[1], inputTensor_[2], inputTensor_[3],
+                                                 inputTensor_[4], inputTensor_[5], opDesc_->beta1, opDesc_->beta2,
+                                                 opDesc_->epsilon, &workspaceSize, &handle);
+        if (ret != ACL_SUCCESS) {
+            (void) aclrtDestroyStream(stream);
+            ERROR_LOG("Get Operator Workspace failed. error code is %d", static_cast<int>(ret));
+            return false;
+        }
+        INFO_LOG("Execute aclnnLazyAdamGetWorkspaceSize success, workspace size %lu", workspaceSize);
+
+        void* workspace = nullptr;
+        if (workspaceSize != 0) {
+            if (aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_NORMAL_ONLY) != ACL_SUCCESS) {
+                ERROR_LOG("Malloc device memory failed");
+                (void) aclrtDestroyStream(stream);
+                return false;
+            }
+        }
+
+        ret = aclnnLazyAdam(workspace, workspaceSize, handle, stream);
+        if (ret != ACL_SUCCESS) {
+            (void) aclrtDestroyStream(stream);
+            ERROR_LOG("Execute Operator failed. error code is %d", static_cast<int>(ret));
+            return false;
+        }
+        INFO_LOG("Execute aclnnLazyAdam success");
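+        // Note: like other aclnn entry points, this custom op follows CANN's two-phase calling
+        // convention -- aclnnLazyAdamGetWorkspaceSize() above sizes the workspace and returns an
+        // executor handle, while aclnnLazyAdam() only enqueues the kernel on the stream, so the
+        // synchronization below is what actually waits for the result.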
error code is %d", static_cast(ret)); + (void) aclrtDestroyStream(stream); + return false; + } + INFO_LOG("Synchronize stream success"); + + // 把输入数据:inputM inputV inputVar 作为输出数据拷贝出来 + for (size_t i = 0; i < OUTPUT_SIZE; ++i) { + int inputTensorIndex = i + INPUT_TENSOR_OFFSET; // 加上输入tensor偏移值 + auto size = GetInputSize(inputTensorIndex); + aclrtMemcpyKind kind = ACL_MEMCPY_DEVICE_TO_HOST; + if (g_isDevice) { + kind = ACL_MEMCPY_DEVICE_TO_DEVICE; + } + if (aclrtMemcpy(hostOutputs_[i], size, devInputs_[inputTensorIndex], size, kind) != ACL_SUCCESS) { + INFO_LOG("Copy output[%zu] success", i); + (void) aclrtDestroyStream(stream); + return false; + } + INFO_LOG("Copy output[%zu] success", i); + } + + (void) aclrtDestroyStream(stream); + return true; + } +} // namespace AclnnLazyAdam \ No newline at end of file diff --git a/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/src/operator_desc.cpp b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/src/operator_desc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..13602e17a3dd0b722d096a876336bfcaff3361e8 --- /dev/null +++ b/cust_op/fused_lazy_adam/aclnn_lazy_adam_test/src/operator_desc.cpp @@ -0,0 +1,57 @@ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. 
+==============================================================================*/ + +#include "operator_desc.h" + +#include "common.h" + +namespace AclnnLazyAdam { +using namespace std; + +OperatorDesc::OperatorDesc() {} + +OperatorDesc::~OperatorDesc() +{ + for (auto* desc : inputDesc) { + aclDestroyTensorDesc(desc); + } + for (auto* desc : outputDesc) { + aclDestroyTensorDesc(desc); + } +} + +OperatorDesc& OperatorDesc::AddInputTensorDesc(aclDataType dataType, int numDims, const int64_t* dims, aclFormat format) +{ + aclTensorDesc* desc = aclCreateTensorDesc(dataType, numDims, dims, format); + if (desc == nullptr) { + ERROR_LOG("create tensor failed"); + return *this; + } + inputDesc.emplace_back(desc); + return *this; +} + +OperatorDesc& OperatorDesc::AddOutputTensorDesc(aclDataType dataType, int numDims, const int64_t* dims, + aclFormat format) +{ + aclTensorDesc* desc = aclCreateTensorDesc(dataType, numDims, dims, format); + if (desc == nullptr) { + ERROR_LOG("create tensor failed"); + return *this; + } + outputDesc.emplace_back(desc); + return *this; +} +} // namespace AclnnLazyAdam \ No newline at end of file diff --git a/cust_op/fused_lazy_adam/lazy_adam.json b/cust_op/fused_lazy_adam/lazy_adam.json new file mode 100644 index 0000000000000000000000000000000000000000..e6fc2c001c0fc6d889a828df637fa66c78cb2e7b --- /dev/null +++ b/cust_op/fused_lazy_adam/lazy_adam.json @@ -0,0 +1,117 @@ +[ + { + "op": "LazyAdam", + "language": "cpp", + "input_desc": [ + { + "name": "gradient", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "fp32" + ] + }, + { + "name": "indices", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "int32" + ] + }, + { + "name": "inputM", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "fp32" + ] + }, + { + "name": "inputV", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "fp32" + ] + }, + { + "name": "inputVar", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "fp32" + ] + }, + { + "name": "lr", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "fp32" + ] + } + ], + "output_desc": [ + { + "name": "inputM", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "fp32" + ] + }, + { + "name": "inputV", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "fp32" + ] + }, + { + "name": "inputVar", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "fp32" + ] + } + ], + "attr": [ + { + "name": "beta1", + "param_type": "required", + "type": "float" + }, + { + "name": "beta2", + "param_type": "required", + "type": "float" + }, + { + "name": "epsilon", + "param_type": "required", + "type": "float" + } + ] + } +] \ No newline at end of file diff --git a/cust_op/fused_lazy_adam/op_host/lazy_adam.cpp b/cust_op/fused_lazy_adam/op_host/lazy_adam.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2c288729d211a9506bbf1063f6971e9a77cbb2fd --- /dev/null +++ b/cust_op/fused_lazy_adam/op_host/lazy_adam.cpp @@ -0,0 +1,222 @@ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "lazy_adam_tiling.h"
+#include "register/op_def_registry.h"
+#include "tiling/platform/platform_ascendc.h"
+
+namespace optiling {
+constexpr int BLOCK_SIZE = 32;
+constexpr int RESERVE_UB_SIZE = 20 * 1024;
+constexpr int DATA_NUM_PER_COMPUTE = 8;
+constexpr int32_t USR_SIZE = 256;
+constexpr int32_t SYS_WORKSPACE_SIZE = 16 * 1024 * 1024;
+
+template <typename T>
+static ge::graphStatus CheckNullPointer(T* pointer, const char* errorMessage)
+{
+    if (pointer == nullptr) {
+        printf("%s nullptr\n", errorMessage);
+        return ge::GRAPH_FAILED;
+    }
+
+    return ge::GRAPH_SUCCESS;
+}
+
+static ge::graphStatus LazyAdamTilingFunc(gert::TilingContext* context)
+{
+    size_t* currentWorkspace = context->GetWorkspaceSizes(1);
+    if (CheckNullPointer(currentWorkspace, "currentWorkspace") != ge::GRAPH_SUCCESS) {
+        return ge::GRAPH_FAILED;
+    }
+    currentWorkspace[0] = SYS_WORKSPACE_SIZE + USR_SIZE;
+
+    LazyAdamTilingData tiling;
+    const gert::StorageShape* indicesShape = context->GetInputShape(1);
+    const gert::StorageShape* inputMShape = context->GetInputShape(2);
+    uint64_t dim0 = inputMShape->GetStorageShape().GetDim(0);
+    uint64_t dim1 = indicesShape->GetStorageShape().GetDim(0);
+    uint64_t dim2 = inputMShape->GetStorageShape().GetDim(1);
+    ge::DataType inputMDtype = context->GetInputDesc(2)->GetDataType();
+    int inputMDtypeSize = ge::GetSizeByDataType(inputMDtype);
+    ge::DataType indicesDtype = context->GetInputDesc(1)->GetDataType();
+    int indicesDtypeSize = ge::GetSizeByDataType(indicesDtype);
+
+    auto attrs = context->GetAttrs();
+
+    float beta1 = *attrs->GetAttrPointer<float>(0);
+    float beta2 = *attrs->GetAttrPointer<float>(1);
+    float epsilon = *attrs->GetAttrPointer<float>(2);
+
+    auto platformInfo = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
+    uint32_t coreNum = platformInfo.GetCoreNum();
+    if (coreNum == 0) {
+        return ge::GRAPH_FAILED;
+    }
+    uint64_t ub;
+    platformInfo.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub);
+    ub = ub - RESERVE_UB_SIZE;
+    // divide the UB size by the per-row footprint to get the rows processed per pass
+    uint64_t row = ub / (dim2 * inputMDtypeSize * DATA_NUM_PER_COMPUTE + 1 * indicesDtypeSize);
+    if (row > dim1) {
+        row = dim1;
+    }
+
+    // round allocations up to a multiple of 32 bytes: (num + 31) / 32 * 32
+    uint64_t indicesAllocSize = (row * indicesDtypeSize + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE;
+    uint64_t otherAllocSize = (row * inputMDtypeSize * dim2 + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE;
+    // workload assigned to each of the first coreNum - 1 cores
+    uint64_t batch = dim1 / coreNum;
+    // number of cores actually used
+    context->SetBlockDim(coreNum);
+    uint64_t loopCount = batch / row;           // per-core workload divided by rows per UB pass
+    uint64_t rowLeft = batch - row * loopCount; // rows left on this core after loopCount full passes
+
+    // workload assigned to the last core
+    uint64_t batchTail = dim1 - batch * (coreNum - 1); // also correct when dim1 divides coreNum exactly
+    uint64_t loopCountTail = batchTail / row;
+    uint64_t rowLeftTail = batchTail - row * loopCountTail;
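+    // Worked example with made-up numbers (not a real shape): for dim1 = 100, coreNum = 8
+    // and row = 5, each of the first 7 cores gets batch = 100 / 8 = 12 rows
+    // (loopCount = 2 full passes, rowLeft = 2 leftover rows), while the last core gets
+    // batchTail = 100 - 12 * 7 = 16 rows (loopCountTail = 3, rowLeftTail = 1).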
+    tiling.set_beta1(beta1);
+    tiling.set_beta2(beta2);
+    tiling.set_epsilon(epsilon);
+    tiling.set_dim0(dim0);
+    tiling.set_dim1(dim1);
+    tiling.set_dim2(dim2);
+    tiling.set_row(row);                           // rows each AI core can process per pass
+    tiling.set_indicesAllocSize(indicesAllocSize); // buffer size to allocate for indices
+    tiling.set_otherAllocSize(otherAllocSize);     // buffer size to allocate for the non-indices inputs
+    tiling.set_batch(batch);                       // workload of each of the first coreNum - 1 cores
+    tiling.set_loopCount(loopCount);               // in-core loop count on the first coreNum - 1 cores
+    tiling.set_rowLeft(rowLeft);                   // rows left on the first coreNum - 1 cores after loopCount passes
+    tiling.set_loopCountTail(loopCountTail);       // in-core loop count on the last core
+    tiling.set_rowLeftTail(rowLeftTail);           // rows left on the last core after loopCountTail passes
+    tiling.set_coreNum(coreNum);
+
+    tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
+    context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
+
+    return ge::GRAPH_SUCCESS;
+}
+} // namespace optiling
+
+namespace ge {
+static ge::graphStatus LazyAdamInferShape(gert::InferShapeContext* context)
+{
+    if (optiling::CheckNullPointer(context, "context") != ge::GRAPH_SUCCESS) {
+        return ge::GRAPH_FAILED;
+    }
+
+    gert::Shape* outputMShape = context->GetOutputShape(0);
+    if (optiling::CheckNullPointer(outputMShape, "outputMShape") != ge::GRAPH_SUCCESS) {
+        return ge::GRAPH_FAILED;
+    }
+    const gert::Shape* inputMShape = context->GetInputShape(2);
+    if (optiling::CheckNullPointer(inputMShape, "inputMShape") != ge::GRAPH_SUCCESS) {
+        return ge::GRAPH_FAILED;
+    }
+    *outputMShape = *inputMShape;
+
+    gert::Shape* outputVShape = context->GetOutputShape(1);
+    if (optiling::CheckNullPointer(outputVShape, "outputVShape") != ge::GRAPH_SUCCESS) {
+        return ge::GRAPH_FAILED;
+    }
+    const gert::Shape* inputVShape = context->GetInputShape(3);
+    if (optiling::CheckNullPointer(inputVShape, "inputVShape") != ge::GRAPH_SUCCESS) {
+        return ge::GRAPH_FAILED;
+    }
+    *outputVShape = *inputVShape;
+
+    gert::Shape* outputVarShape = context->GetOutputShape(2);
+    if (optiling::CheckNullPointer(outputVarShape, "outputVarShape") != ge::GRAPH_SUCCESS) {
+        return ge::GRAPH_FAILED;
+    }
+    const gert::Shape* inputVarShape = context->GetInputShape(4);
+    if (optiling::CheckNullPointer(inputVarShape, "inputVarShape") != ge::GRAPH_SUCCESS) {
+        return ge::GRAPH_FAILED;
+    }
+    *outputVarShape = *inputVarShape;
+
+    return GRAPH_SUCCESS;
+}
+
+static ge::graphStatus LazyAdamInferDataType(gert::InferDataTypeContext* context)
+{
+    return GRAPH_SUCCESS;
+}
+} // namespace ge
+
+namespace ops {
+class LazyAdam : public OpDef {
+public:
+    explicit LazyAdam(const char* name) : OpDef(name)
+    {
+        this->Input("gradient")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
+        this->Input("indices")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
+        this->Input("inputM")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
+        this->Input("inputV")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
+        this->Input("inputVar")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
+        this->Input("lr")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
+        this->Output("inputM")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
+        this->Output("inputV")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
this->Output("inputVar") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Attr("beta1").Float(); + this->Attr("beta2").Float(); + this->Attr("epsilon").Float(); + this->SetInferShape(ge::LazyAdamInferShape).SetInferDataType(ge::LazyAdamInferDataType); + this->AICore().SetTiling(optiling::LazyAdamTilingFunc); + this->AICore().AddConfig("ascend910b"); + this->AICore().AddConfig("ascend910c"); + } +}; + +OP_ADD(LazyAdam); +} // namespace ops diff --git a/cust_op/fused_lazy_adam/op_host/lazy_adam_tiling.h b/cust_op/fused_lazy_adam/op_host/lazy_adam_tiling.h new file mode 100644 index 0000000000000000000000000000000000000000..4f1534a4d4c25bee0f16b426d2b7ddd5a0c70895 --- /dev/null +++ b/cust_op/fused_lazy_adam/op_host/lazy_adam_tiling.h @@ -0,0 +1,41 @@ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + +#ifndef LAZY_ADAM_TILING_H +#define LAZY_ADAM_TILING_H +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(LazyAdamTilingData) +TILING_DATA_FIELD_DEF(float, beta1); +TILING_DATA_FIELD_DEF(float, beta2); +TILING_DATA_FIELD_DEF(float, epsilon); +TILING_DATA_FIELD_DEF(int32_t, dim0); +TILING_DATA_FIELD_DEF(int32_t, dim1); +TILING_DATA_FIELD_DEF(int32_t, dim2); +TILING_DATA_FIELD_DEF(int32_t, row); +TILING_DATA_FIELD_DEF(int32_t, indicesAllocSize); +TILING_DATA_FIELD_DEF(int32_t, otherAllocSize); +TILING_DATA_FIELD_DEF(int32_t, batch); +TILING_DATA_FIELD_DEF(int32_t, loopCount); +TILING_DATA_FIELD_DEF(int32_t, rowLeft); +TILING_DATA_FIELD_DEF(int32_t, loopCountTail); +TILING_DATA_FIELD_DEF(int32_t, rowLeftTail); +TILING_DATA_FIELD_DEF(int32_t, coreNum); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(LazyAdam, LazyAdamTilingData) +} // namespace optiling +#endif // LAZY_ADAM_TILING_H \ No newline at end of file diff --git a/cust_op/fused_lazy_adam/op_kernel/lazy_adam.cpp b/cust_op/fused_lazy_adam/op_kernel/lazy_adam.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e0ad8e4542247e825c06425bfa2ce4cb4f73675f --- /dev/null +++ b/cust_op/fused_lazy_adam/op_kernel/lazy_adam.cpp @@ -0,0 +1,243 @@ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. 
+==============================================================================*/
+
+#include "kernel_operator.h"
+
+using namespace AscendC;
+
+template <typename T>
+class LazyAdam {
+public:
+    __aicore__ inline LazyAdam() {}
+
+    // Initialization: binds GM addresses and sets up UB allocations
+    __aicore__ inline void Init(GM_ADDR gradient, GM_ADDR indices, GM_ADDR inputM, GM_ADDR inputV, GM_ADDR inputVar,
+                                GM_ADDR lr, GM_ADDR inputMRef, GM_ADDR inputVRef, GM_ADDR inputVarRef, float beta1,
+                                float beta2, float epsilon, int32_t dim0, int32_t dim1, int32_t dim2, int32_t row,
+                                int32_t indicesAllocSize, int32_t otherAllocSize, int32_t batch, int32_t loopCount,
+                                int32_t rowLeft, int32_t loopCountTail, int32_t rowLeftTail, int32_t coreNum)
+    {
+        ASSERT(GetBlockNum() != 0 && "block dim can not be zero!");
+        // attribute assignment
+        this->beta1 = beta1;
+        this->beta2 = beta2;
+        this->epsilon = epsilon;
+        // tiling data assignment
+        this->dim0 = dim0;
+        this->dim1 = dim1;
+        this->dim2 = dim2;
+        this->row = row;
+        this->batch = batch;
+        this->loopCount = loopCount;
+        this->rowLeft = rowLeft;
+        this->loopCountTail = loopCountTail;
+        this->rowLeftTail = rowLeftTail;
+        this->coreNum = coreNum;
+        // flat GM sizes of the inputs
+        int32_t shape = this->dim0 * this->dim2;
+        int32_t shapeIndices = this->dim1 * 1;
+        int32_t shapeGradient = this->dim1 * this->dim2;
+        this->gmGradient.SetGlobalBuffer((__gm__ T*)gradient + this->batch * this->dim2 * get_block_idx(),
+                                         shapeGradient);
+        this->gmIndices.SetGlobalBuffer((__gm__ int32_t*)indices + this->batch * get_block_idx(), shapeIndices);
+
+        this->gmInputM.SetGlobalBuffer((__gm__ T*)inputM, shape);
+        this->gmInputV.SetGlobalBuffer((__gm__ T*)inputV, shape);
+        this->gmInputVar.SetGlobalBuffer((__gm__ T*)inputVar, shape);
+
+        this->gmLearningRate.SetGlobalBuffer((__gm__ T*)lr, sizeof(float));
+        this->lr = this->gmLearningRate.GetValue(0);
+
+        // point the output addresses at the input addresses (the update is in place)
+        inputMRef = inputM;
+        inputVRef = inputV;
+        inputVarRef = inputVar;
+
+        // per-pass UB allocations, already rounded up to 32-byte alignment by the tiling
+        this->pipe.InitBuffer(this->inQueGradient, 1, otherAllocSize);
+        this->pipe.InitBuffer(this->inQueIndices, 1, indicesAllocSize);
+        this->pipe.InitBuffer(this->queMSlice, 1, otherAllocSize);
+        this->pipe.InitBuffer(this->queVSlice, 1, otherAllocSize);
+        this->pipe.InitBuffer(this->queVarSlice, 1, otherAllocSize);
+
+        this->pipe.InitBuffer(this->calcBufM, otherAllocSize);
+        this->updateM = this->calcBufM.template Get<T>();
+
+        this->pipe.InitBuffer(this->calcBufV, otherAllocSize);
+        this->updateV = this->calcBufV.template Get<T>();
+
+        this->pipe.InitBuffer(this->calcBufVar, otherAllocSize);
+        this->updateVar = this->calcBufVar.template Get<T>();
+
+        this->pipe.InitBuffer(this->calcBuf, otherAllocSize);
+        this->temp = this->calcBuf.template Get<T>();
+    }
+
+    // Core processing: drives the CopyIn/Compute pipeline stages of the vector op
+    // (the copy-out is folded into Compute in this implementation)
+    __aicore__ inline void Process()
+    {
+        if (get_block_idx() == this->coreNum - 1) {
+            for (int32_t i = 0; i < this->loopCountTail; i++) {
+                CopyIn(i, this->row);
+                Compute(i, this->row);
+            }
+            // handle the tail block (this must test rowLeftTail, the last core's remainder)
+            if (this->rowLeftTail > 0) {
+                CopyIn(this->loopCountTail, this->rowLeftTail);
+                Compute(this->loopCountTail, this->rowLeftTail);
+            }
+        } else {
+            for (int32_t i = 0; i < this->loopCount; i++) {
+                CopyIn(i, this->row);
+                Compute(i, this->row);
+            }
+            // handle the tail block
+            if (this->rowLeft > 0) {
+                CopyIn(this->loopCount, this->rowLeft);
+                Compute(this->loopCount, this->rowLeft);
+            }
+        }
+    }
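+    // Per-core schedule implied by the host-side tiling (see LazyAdamTilingFunc): cores
+    // 0..coreNum-2 each run loopCount passes of `row` rows plus one pass of `rowLeft`
+    // leftover rows, while the last core runs loopCountTail passes plus `rowLeftTail` rows.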
+private:
+    // CopyIn stage: moves the gradient and indices slices from GM into UB; called by Process()
+    __aicore__ inline void CopyIn(int32_t progress, int32_t row)
+    {
+        LocalTensor<T> localGradient = this->inQueGradient.template AllocTensor<T>();
+        uint32_t gradientDataLen = row * this->dim2 * sizeof(T);
+        // {1, len, 0, 0, 0}: one contiguous block of len bytes (unaligned copy); the
+        // source/destination strides and the reserved field are zero
+        DataCopyExtParams gradientParams{1, gradientDataLen, 0, 0, 0};
+        // padding parameters for the unaligned copy
+        DataCopyPadExtParams<T> gradientPadParams{true, 0, 2, 0};
+        DataCopyPad(localGradient, this->gmGradient[progress * this->row * this->dim2], gradientParams,
+                    gradientPadParams);
+
+        LocalTensor<int32_t> localIndices = this->inQueIndices.template AllocTensor<int32_t>();
+        uint32_t indicesDataLen = row * sizeof(int32_t);
+        DataCopyExtParams indicesParams{1, indicesDataLen, 0, 0, 0};
+        DataCopyPadExtParams<int32_t> indicesPadParams{true, 0, 2, 0};
+        DataCopyPad(localIndices, this->gmIndices[progress * this->row], indicesParams, indicesPadParams);
+
+        this->inQueGradient.EnQue(localGradient);
+        this->inQueIndices.EnQue(localIndices);
+    }
+
+    // Compute stage: gathers rows by index, applies the lazy Adam update and scatters
+    // the results back; called by Process()
+    __aicore__ inline void Compute(int32_t progress, int32_t row)
+    {
+        LocalTensor<T> localGradient = this->inQueGradient.template DeQue<T>();
+        LocalTensor<int32_t> localIndices = this->inQueIndices.template DeQue<int32_t>();
+        Muls(localIndices, localIndices, this->dim2, row); // turn row indices into element offsets
+        // gather the m/v/var slices from GM according to indices
+        LocalTensor<T> localMSlice = this->queMSlice.template AllocTensor<T>();
+        LocalTensor<T> localVSlice = this->queVSlice.template AllocTensor<T>();
+        LocalTensor<T> localVarSlice = this->queVarSlice.template AllocTensor<T>();
+
+        pipe_barrier(PIPE_ALL);
+
+        int32_t index = 0;
+        for (int32_t i = 0; i < row; i++) {
+            index = localIndices.GetValue(i);
+            if (index >= 0) {
+                DataCopy(localMSlice[i * this->dim2], gmInputM[index], this->dim2);
+                DataCopy(localVSlice[i * this->dim2], gmInputV[index], this->dim2);
+                DataCopy(localVarSlice[i * this->dim2], gmInputVar[index], this->dim2);
+            }
+        }
+
+        this->queMSlice.EnQue(localMSlice);
+        this->queVSlice.EnQue(localVSlice);
+        this->queVarSlice.EnQue(localVarSlice);
+        localMSlice = this->queMSlice.template DeQue<T>();
+        localVSlice = this->queVSlice.template DeQue<T>();
+        localVarSlice = this->queVarSlice.template DeQue<T>();
+
+        // update m
+        Muls(localMSlice, localMSlice, this->beta1, row * this->dim2);
+        Muls(this->updateM, localGradient, (1 - this->beta1), row * this->dim2);
+        this->updateM = localMSlice + this->updateM;
+
+        // update v
+        Muls(localVSlice, localVSlice, this->beta2, row * this->dim2);
+        Mul(this->updateV, localGradient, localGradient, row * this->dim2);
+        Muls(this->updateV, this->updateV, (1 - this->beta2), row * this->dim2);
+        this->updateV = localVSlice + this->updateV;
+
+        // update var
+        Abs(this->updateV, this->updateV, row * this->dim2);
+        Sqrt(this->updateVar, this->updateV, row * this->dim2);
+        Adds(this->updateVar, this->updateVar, this->epsilon, row * this->dim2);
+        Muls(this->temp, this->updateM, -this->lr, row * this->dim2);
+        Div(this->updateVar, this->temp, this->updateVar, row * this->dim2);
+        Add(this->updateVar, this->updateVar, localVarSlice, row * this->dim2);
+
+        pipe_barrier(PIPE_ALL);
+
+        // scatter the results back, updating the input tensors in place
+        for (int32_t i = 0; i < row; i++) {
+            index = localIndices.GetValue(i);
+            if (index >= 0) {
+                // the __GET_CODE_CHANNEL__ guard keeps these copies from being misidentified
+                // as matmul at compile time
+#ifndef __GET_CODE_CHANNEL__
+                DataCopy(this->gmInputM[index], this->updateM[i * this->dim2], this->dim2);
+                DataCopy(this->gmInputV[index], this->updateV[i * this->dim2], this->dim2);
+                DataCopy(this->gmInputVar[index], this->updateVar[i * this->dim2], this->dim2);
+#endif
+            }
+        }
+        pipe_barrier(PIPE_ALL);
+
+        this->inQueGradient.FreeTensor(localGradient);
+        this->queMSlice.FreeTensor(localMSlice);
+        this->queVSlice.FreeTensor(localVSlice);
+        this->queVarSlice.FreeTensor(localVarSlice);
+        this->inQueIndices.FreeTensor(localIndices);
+    }
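+    // For each gathered row whose index is >= 0, Compute() applies the lazy Adam step
+    // (rows not referenced by `indices` stay untouched, which is what makes it "lazy"):
+    //     m   = beta1 * m + (1 - beta1) * g
+    //     v   = beta2 * v + (1 - beta2) * g * g
+    //     var = var - lr * m / (sqrt(v) + epsilon)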
+
+private:
+    float lr, beta1, beta2, epsilon;
+    int32_t dim0, dim1, dim2, row, batch, loopCount, rowLeft, loopCountTail, rowLeftTail, coreNum;
+    // NOTE: the template parameters below were lost in extraction; the VECIN/VECCALC
+    // positions are the usual choices for this copy-in/compute pattern.
+    LocalTensor<T> updateM, updateV, updateVar, temp;
+    LocalTensor<int32_t> localIndices;
+    GlobalTensor<T> gmGradient, gmInputM, gmInputV, gmInputVar;
+    GlobalTensor<int32_t> gmIndices;
+    GlobalTensor<T> gmLearningRate;
+    TPipe pipe;
+    TQue<QuePosition::VECIN, 1> inQueGradient, inQueIndices;
+    TQue<QuePosition::VECIN, 1> queMSlice, queVSlice, queVarSlice;
+    TBuf<TPosition::VECCALC> calcBufM;
+    TBuf<TPosition::VECCALC> calcBufV;
+    TBuf<TPosition::VECCALC> calcBufVar;
+    TBuf<TPosition::VECCALC> calcBuf;
+};
+
+extern "C" __global__ __aicore__ void lazy_adam(GM_ADDR gradient, GM_ADDR indices, GM_ADDR inputM, GM_ADDR inputV,
+                                                GM_ADDR inputVar, GM_ADDR lr, GM_ADDR inputMRef, GM_ADDR inputVRef,
+                                                GM_ADDR inputVarRef, GM_ADDR workspace, GM_ADDR tiling)
+{
+    GET_TILING_DATA(tiling_data, tiling);
+    LazyAdam<float> op32;
+    op32.Init(gradient, indices, inputM, inputV, inputVar, lr, inputMRef, inputVRef, inputVarRef, tiling_data.beta1,
+              tiling_data.beta2, tiling_data.epsilon, tiling_data.dim0, tiling_data.dim1, tiling_data.dim2,
+              tiling_data.row, tiling_data.indicesAllocSize, tiling_data.otherAllocSize, tiling_data.batch,
+              tiling_data.loopCount, tiling_data.rowLeft, tiling_data.loopCountTail, tiling_data.rowLeftTail,
+              tiling_data.coreNum);
+#ifdef KERNEL_TASK_TYPE_DEFAULT
+    // Set kernel type with new versions of CANN to avoid matmul error during compiling.
+    // In previous versions of CANN, avoid matmul error by using '#ifndef __GET_CODE_CHANNEL__'.
+    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);
+#endif
+    op32.Process();
+}
\ No newline at end of file
diff --git a/cust_op/fused_lazy_adam/run.sh b/cust_op/fused_lazy_adam/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..63bf7af48fc711d0024b37f6f5f3288c9c1505d3
--- /dev/null
+++ b/cust_op/fused_lazy_adam/run.sh
@@ -0,0 +1,57 @@
+#!/bin/bash
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+set -e
+
+source /etc/profile
+
+# locate msopgen and add it to PATH
+msopgen_path=$(find /usr/local/Ascend/ -name msopgen | grep bin)
+parent_dir=$(dirname "$msopgen_path")
+export PATH=$parent_dir:$PATH
+
+# generate a buildable op project from the op definition with msopgen
+rm -rf ./lazy_adam
+msopgen gen -i lazy_adam.json -f tf -c ai_core-Ascend910B1 -lan cpp -out ./lazy_adam -m 0 -op LazyAdam
+
+cp -rf op_kernel lazy_adam/
+cp -rf op_host lazy_adam/
+
+cd lazy_adam
+
+# make sure msopgen produced a CMakePresets.json in the generated project
+if [ ! -f "CMakePresets.json" ]; then
+    echo "ERROR: CMakePresets.json file does not exist."
+    exit 1
+fi
+
+# disable CRC checksum generation
+sed -i 's/--nomd5/--nomd5 --nocrc/g' ./cmake/makeself.cmake
+
+# point the CANN install path at the toolkit location
+sed -i 's:"/usr/local/Ascend/latest":"/usr/local/Ascend/ascend-toolkit/latest":g' CMakePresets.json
+# change vendor_name so previously installed ops whose vendor_name is "customize" are not overwritten;
+# vendor_name must stay in sync with CUST_PKG_PATH in the aclnn CMakeLists.txt, otherwise the aclnn call fails;
+# and it must not contain "customize", or the config.ini under CANN's vendors directory is parsed
+# incorrectly when multiple op packages are deployed
+sed -i 's:"customize":"mxrec_fused_lazy_adam":g' CMakePresets.json
+
+bash build.sh
+
+# install the op package that was just built
+bash ./build_out/custom_opp*.run
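+
+# After installation the package is deployed under CANN's vendors directory (the same
+# directory whose config.ini is mentioned above). A quick sanity check -- the exact
+# toolkit path may differ on your system:
+#   ls /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/mxrec_fused_lazy_adam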
+
+cd ..
+rm -rf ./lazy_adam
diff --git "a/docs/MindX 6.0.RC1 \351\200\232\344\277\241\347\237\251\351\230\265.xlsx" "b/docs/MindX 6.0.RC1 \351\200\232\344\277\241\347\237\251\351\230\265.xlsx"
new file mode 100644
index 0000000000000000000000000000000000000000..9e14cd61ae280c9144d8def5fec2024625260f48
Binary files /dev/null and "b/docs/MindX 6.0.RC1 \351\200\232\344\277\241\347\237\251\351\230\265.xlsx" differ
diff --git "a/docs/MindX\342\200\242SDK\342\200\2426.0.RC1\342\200\242mxRec\342\200\242\345\205\254\347\275\221\345\234\260\345\235\200\345\222\214\351\202\256\347\256\261\345\234\260\345\235\200.xlsx" "b/docs/MindX\342\200\242SDK\342\200\2426.0.RC1\342\200\242mxRec\342\200\242\345\205\254\347\275\221\345\234\260\345\235\200\345\222\214\351\202\256\347\256\261\345\234\260\345\235\200.xlsx"
new file mode 100644
index 0000000000000000000000000000000000000000..de085d900879882e548b9fa5cd539bb2f698cc0f
Binary files /dev/null and "b/docs/MindX\342\200\242SDK\342\200\2426.0.RC1\342\200\242mxRec\342\200\242\345\205\254\347\275\221\345\234\260\345\235\200\345\222\214\351\202\256\347\256\261\345\234\260\345\235\200.xlsx" differ
diff --git a/docs/build_mxRec_images/centos_build/Dockerfile b/docs/build_mxRec_images/centos_build/Dockerfile
index ee1d98e8a24176391ece7bd610ed9121c67aab30..190ec21bc268b603dbf776256ac6a9d1f6f91593 100644
--- a/docs/build_mxRec_images/centos_build/Dockerfile
+++ b/docs/build_mxRec_images/centos_build/Dockerfile
@@ -114,6 +114,7 @@ RUN pip3.7 install -U pip && \
     pip3.7 install cffi==1.12.3 && \
     pip3.7 install pyyaml && \
     pip3.7 install pathlib2 && \
+    pip3.7 install pandas && \
     pip3.7 install grpcio && \
     pip3.7 install grpcio-tools && \
     pip3.7 install protobuf==3.20.0 && \
@@ -130,6 +131,9 @@ RUN pip3.7 install -U pip && \
     pip3.7 install h5py==3.1.0 && \
     rm -rf /root/.cache/pip
 
+# the CC environment variable is used while installing mpi4py; drop it once installation is done
+RUN unset CC
+
 # 10. set the driver path environment variables
 ARG ASCEND_BASE=/usr/local/Ascend
 ENV LD_LIBRARY_PATH=$ASCEND_BASE/driver/lib64:$ASCEND_BASE/driver/lib64/common:$ASCEND_BASE/driver/lib64/driver:$LD_LIBRARY_PATH
diff --git a/examples/DCNv2/README.md b/examples/DCNv2/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f1940ebeff699260b1a2b071d194df219bfbf637
--- /dev/null
+++ b/examples/DCNv2/README.md
@@ -0,0 +1,54 @@
+# Running the DCNv2 Model
+
+## Code layout
+```shell
+.
+├── config.py            # model configuration
+├── delay_loss_scale.py  # loss-scaling helpers
+├── main_mxrec.py        # main entry point
+├── model.py             # the DCNv2 model
+├── op_impl_mode.ini     # operator execution mode configuration
+├── optimizer.py         # optimizers
+├── README.md            # this document
+└── run.sh               # launch script for the DCNv2 model
+```
+
+## 1. Prepare the data
+Follow the instructions in the criteo_tb directory of the DLRM model to prepare the dataset the model needs, and put it in one directory, e.g. /data/criteo_tb/.
+
+## 2. Prepare the runtime environment
+The runtime environment can be set up by following the "Installation and Deployment" chapter of the [mxRec User Guide](https://www.hiascend.com/document/detail/zh/mind-sdk/60rc1/mxRec/mxrecug/mxrecug_0007.html).
+
+## 3. Install mxRec
+The mxRec package can be downloaded via the links given in "Installation and Deployment" > "Environment Preparation" > "Obtaining the Software Package" of the [mxRec User Guide](https://www.hiascend.com/document/detail/zh/mind-sdk/60rc1/mxRec/mxrecug/mxrecug_0007.html); pick the package for your architecture (x86 or arm). After downloading, unpack it and enter the extracted directory (mindxsdk-mxrec), which looks like:
+```shell
+.
+├── cust_op
+│   └── cust_op_by_addr
+├── examples
+│   ├── DCNv2
+│   ├── demo
+│   └── dlrm
+├── tf1_whl
+│   └── mx_rec-{version}-py3-none-linux_x86_64.whl # version is the release number
+├── tf2_whl
+│   └── mx_rec-{version}-py3-none-linux_x86_64.whl # version is the release number
+└── version.info
+```
+The tf1_whl and tf2_whl directories contain the mxRec packages adapted to tf1 and tf2 respectively; install whichever one you need (via pip/pip3 install <package>).
+Note the directory mxRec is installed into, e.g. /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec.
+
+## 4. Run the DCNv2 model
+After the steps above, the DCNv2 model can be run with run.sh, which defaults to 8 devices. It takes 5 positional arguments: so_path, mx_rec_package_path, hccl_cfg_json, dlrm_criteo_data_path and ip. The command is:
+```shell
+bash run.sh {so_path} {mx_rec_package_path} {hccl_cfg_json} {dlrm_criteo_data_path} {ip}
+```
+* so_path: the directory of mxRec's shared libraries, normally the libasc directory under the mxRec install directory, e.g. /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec/libasc.
+* mx_rec_package_path: the mxRec install directory, e.g. /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec.
+* hccl_cfg_json: the HCCL communication configuration file; if the ip argument is given this is unused, just pass an empty string "".
+* dlrm_criteo_data_path: the dataset directory, e.g. /data/criteo_tb/.
+* ip: the IP address of the machine running the model; configuring it is recommended. A concrete example follows below.
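+
+For example, assuming the default install location shown above and the dataset in /data/criteo_tb/ (substitute your own paths; the IP below is a placeholder), an IP-based launch could look like:
+```shell
+bash run.sh /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec/libasc \
+            /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec \
+            "" \
+            /data/criteo_tb/ \
+            192.168.1.10
+```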
diff --git a/examples/DCNv2/config.py b/examples/DCNv2/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b7c9bce084a62d1dd3ea8a381e9ae55c6d9fe7c
--- /dev/null
+++ b/examples/DCNv2/config.py
@@ -0,0 +1,208 @@
+# coding=utf-8
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import enum
+import os
+
+import tensorflow as tf
+from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig
+
+from mx_rec.constants.constants import CacheModeEnum
+
+SSD_DATA_PATH = ["ssd_data"]
+
+
+class LearningRateScheduler:
+    """
+    LR Scheduler combining Polynomial Decay with Warmup at the beginning.
+    TF-based cond operations necessary for performance in graph mode.
+    """
+
+    def __init__(self, base_lr_dense, base_lr_sparse, warmup_steps, decay_start_step, decay_steps):
+        self.warmup_steps = tf.constant(warmup_steps, dtype=tf.int32)
+        self.decay_start_step = tf.constant(decay_start_step, dtype=tf.int32)
+        self.decay_steps = tf.constant(decay_steps)
+        self.decay_end_step = decay_start_step + decay_steps  # 65041
+        self.poly_power = 2.0
+        self.base_lr_dense = base_lr_dense
+        self.base_lr_sparse = base_lr_sparse
+
+    def calc(self, global_step):
+        # used for the warmup stage
+        warmup_step = tf.cast(1 / self.warmup_steps, tf.float32)
+        lr_factor_warmup = 1 - tf.cast(self.warmup_steps - global_step, tf.float32) * warmup_step
+        lr_factor_warmup = tf.cast(lr_factor_warmup, tf.float32)
+        # used for the constant stage
+        lr_factor_constant = tf.cast(1.0, tf.float32)
+
+        # used for the decay stage
+        lr_factor_decay = (self.decay_end_step - global_step) / self.decay_steps
+        lr_factor_decay = tf.math.pow(lr_factor_decay, self.poly_power)
+        lr_factor_decay = tf.cast(lr_factor_decay, tf.float32)
+        sparse_after_decay = tf.cast(1 / self.decay_steps, tf.float32)
+
+        lr_factor_decay_sparse = tf.cond(
+            global_step < self.decay_end_step,
+            lambda: lr_factor_decay,
+            lambda: sparse_after_decay,
+        )
+
+        lr_factor_decay_dense = tf.cond(
+            global_step < self.decay_end_step,
+            lambda: lr_factor_decay,
+            lambda: sparse_after_decay,
+        )
+
+        poly_schedule_sparse = tf.cond(
+            global_step < self.decay_start_step,
+            lambda: lr_factor_constant,
+            lambda: lr_factor_decay_sparse,
+        )
+
+        poly_schedule_dense = tf.cond(
+            global_step < self.decay_start_step,
+            lambda: lr_factor_constant,
+            lambda: lr_factor_decay_dense,
+        )
+
+        lr_factor_sparse = tf.cond(
+            global_step < self.warmup_steps, lambda: lr_factor_warmup, lambda: poly_schedule_sparse
+        )
+
+        lr_factor_dense = tf.cond(
+            global_step < self.warmup_steps, lambda: lr_factor_warmup, lambda: poly_schedule_dense
+        )
+
+        lr_sparse = self.base_lr_sparse * lr_factor_sparse
+        lr_dense = self.base_lr_dense * lr_factor_dense
+        return lr_dense, lr_sparse
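+# For a step s, the scheduler above yields the factor:
+#   s < warmup_steps                       : 1 - (warmup_steps - s) / warmup_steps
+#   warmup_steps <= s < decay_start_step   : 1.0
+#   decay_start_step <= s < decay_end_step : ((decay_end_step - s) / decay_steps) ** 2
+#   s >= decay_end_step                    : 1 / decay_steps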
+
+
+class Config:
+    def __init__(self):
+        self.rank_id = int(os.getenv("OMPI_COMM_WORLD_RANK")) if os.getenv("OMPI_COMM_WORLD_RANK") else None
+        tmp = os.getenv("TRAIN_RANK_SIZE")
+        if tmp is None:
+            raise ValueError("please export TRAIN_RANK_SIZE")
+        self.rank_size = int(tmp)
+
+        self.data_path = os.getenv("DLRM_CRITEO_DATA_PATH")
+        self.train_file_pattern = "train"
+        self.test_file_pattern = "test"
+
+        self.batch_size = 8192
+        self.line_per_sample = 1024
+        self.train_epoch = 3
+        self.test_epoch = 1
+        self.perform_shuffle = False
+
+        self.key_type = tf.int64
+        self.label_type = tf.float32
+        self.value_type = tf.int64
+
+        self.feat_cnt = 26
+        self.__set_emb_table_size()
+
+        self.field_num = 26
+        self.send_count = 46000 // self.rank_size
+
+        self.emb_dim = 128
+        self.hashtable_threshold = 1
+
+        self.USE_PIPELINE_TEST = False
+
+        # dynamic learning rate schedule
+        GLOBAL_BATCH_SIZE = 8192 * 8
+        LR_SCHEDULE_STEPS = [
+            int(2750 * 55296 / GLOBAL_BATCH_SIZE),
+            int(49315 * 55296 / GLOBAL_BATCH_SIZE),
+            int(27772 * 55296 / GLOBAL_BATCH_SIZE),
+        ]
+        self.global_step = tf.Variable(0, trainable=False)
+        _lr_scheduler = LearningRateScheduler(
+            28.443,
+            33.71193,
+            LR_SCHEDULE_STEPS[0],
+            LR_SCHEDULE_STEPS[1],
+            LR_SCHEDULE_STEPS[2],
+        )
+        self.learning_rate = _lr_scheduler.calc(self.global_step)
+
+    def __set_emb_table_size(self):
+        self.cache_mode = os.getenv("CACHE_MODE")
+        if self.cache_mode is None:
+            raise ValueError("please export CACHE_MODE environment variable, support:[HBM, DDR, SSD]")
+
+        if self.cache_mode == CacheModeEnum.HBM.value:
+            self.dev_vocab_size = 24_000_000 * self.rank_size
+            self.host_vocab_size = 0
+        elif self.cache_mode == CacheModeEnum.DDR.value:
+            self.dev_vocab_size = 500_000 * self.rank_size
+            self.host_vocab_size = 24_000_000 * self.rank_size
+        elif self.cache_mode == CacheModeEnum.SSD.value:
+            self.dev_vocab_size = 100_000 * self.rank_size
+            self.host_vocab_size = 2_000_000 * self.rank_size
+            self.ssd_vocab_size = 24_000_000 * self.rank_size
+        else:
+            raise ValueError(f"get CACHE_MODE:{self.cache_mode}, expect in [HBM, DDR, SSD]")
+
+    def get_emb_table_cfg(self):
+        if self.cache_mode == CacheModeEnum.HBM.value:
+            return {"device_vocabulary_size": self.dev_vocab_size}
+        elif self.cache_mode == CacheModeEnum.DDR.value:
+            return {"device_vocabulary_size": self.dev_vocab_size,
+                    "host_vocabulary_size": self.host_vocab_size}
+        elif self.cache_mode == CacheModeEnum.SSD.value:
+            return {"device_vocabulary_size": self.dev_vocab_size,
+                    "host_vocabulary_size": self.host_vocab_size,
+                    "ssd_vocabulary_size": self.ssd_vocab_size,
+                    "ssd_data_path": SSD_DATA_PATH}
+        else:
+            raise RuntimeError(f"get CACHE_MODE:{self.cache_mode}, check Config.__set_emb_table_size implementation")
+
+
+def sess_config(dump_data=False, dump_path="./dump_output", dump_steps="0|1|2"):
+    session_config = tf.ConfigProto(allow_soft_placement=False,
+                                    log_device_placement=False)
+    session_config.gpu_options.allow_growth = True
+    custom_op = session_config.graph_options.rewrite_options.custom_optimizers.add()
+    custom_op.name = "NpuOptimizer"
+    custom_op.parameter_map["mix_compile_mode"].b = False
+    custom_op.parameter_map["use_off_line"].b = True
+    custom_op.parameter_map["min_group_size"].b = 1
+    # optional alternative: level0:pairwise;level1:pairwise
+    custom_op.parameter_map["HCCL_algorithm"].s = tf.compat.as_bytes("level0:fullmesh;level1:fullmesh")
+    custom_op.parameter_map["enable_data_pre_proc"].b = True
+    custom_op.parameter_map["iterations_per_loop"].i = 10
+    custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision")
+    custom_op.parameter_map["hcom_parallel"].b = False
+    custom_op.parameter_map["op_precision_mode"].s = tf.compat.as_bytes("op_impl_mode.ini")
+    custom_op.parameter_map["op_execute_timeout"].i = 2000
+    custom_op.parameter_map["variable_memory_max_size"].s = tf.compat.as_bytes(
+        str(13 * 1024 * 1024 * 1024))  # total 31, need 13
+    custom_op.parameter_map["graph_memory_max_size"].s = tf.compat.as_bytes(str(18 * 1024 * 1024 * 1024))  # need 25
+    custom_op.parameter_map["stream_max_parallel_num"].s = tf.compat.as_bytes("DNN_VM_AICPU:3,AIcoreEngine:3")
+
+    if dump_data:
+        custom_op.parameter_map["enable_dump"].b = True
+        custom_op.parameter_map["dump_path"].s = tf.compat.as_bytes(dump_path)
+        custom_op.parameter_map["dump_step"].s = tf.compat.as_bytes(dump_steps)
+        custom_op.parameter_map["dump_mode"].s = tf.compat.as_bytes("all")
+
+    session_config.graph_options.rewrite_options.remapping = RewriterConfig.OFF
+    session_config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF
+
+    return session_config
diff --git a/examples/DCNv2/delay_loss_scale.py b/examples/DCNv2/delay_loss_scale.py
index a9ee5e645eeb70d7d3b0d310f9405c569919a38d..821b221090523c577261578e3950153fb4429fa3 100644
--- a/examples/DCNv2/delay_loss_scale.py
+++ b/examples/DCNv2/delay_loss_scale.py
@@ -21,13 +21,13 @@ from tensorflow.compat.v1.train import Optimizer
 class DenseLossScaleOptimizer:
     def __init__(self, opt, loss_scale):
         if not isinstance(opt, Optimizer):
-            raise ValueError('"opt" must be an
instance of Optimizer, but got: %s' % type(opt)) + raise ValueError("`opt` must be an instance of Optimizer, but got: %s" % type(opt)) self._optimizer = opt self._loss_scale = tf.convert_to_tensor(loss_scale, tf.float32) - self._optimizer._lr = self._optimizer._lr / self._loss_scale + _scale_learning_rate(self._optimizer, loss_scale) def compute_gradients(self, loss, var_list=None): - return self._optimizer.compute_gradients(loss*self._loss_scale, var_list=var_list) + return self._optimizer.compute_gradients(loss * self._loss_scale, var_list=var_list) def apply_gradients(self, avg_grads): return self._optimizer.apply_gradients(avg_grads) @@ -36,13 +36,26 @@ class DenseLossScaleOptimizer: class SparseLossScaleOptimizer: def __init__(self, opt, loss_scale): if not isinstance(opt, Optimizer): - raise ValueError('"opt" must be an instance of Optimizer, but got: %s' % type(opt)) + raise ValueError("`opt` must be an instance of Optimizer, but got: %s" % type(opt)) self._optimizer = opt self._loss_scale = tf.convert_to_tensor(loss_scale, tf.float32) - self._optimizer._lr = self._optimizer._lr / self._loss_scale + _scale_learning_rate(self._optimizer, loss_scale) def compute_gradients(self, loss, var_list=None): - return tf.gradients(loss*self._loss_scale, var_list) + return tf.gradients(loss * self._loss_scale, var_list) def apply_gradients(self, grads_and_vars): - return self._optimizer.apply_gradients(grads_and_vars) \ No newline at end of file + return self._optimizer.apply_gradients(grads_and_vars) + + +def _scale_learning_rate(opt: Optimizer, loss_scale: float) -> None: + if loss_scale == 0: + raise ValueError("`loss_scale` can not be zero") + if hasattr(opt, "_learning_rate"): + # `SGD` or `Adagrad` + opt._learning_rate = opt._learning_rate / tf.convert_to_tensor(loss_scale, tf.float32) + elif hasattr(opt, "_lr"): + # `Adam` + opt._lr = opt._lr / tf.convert_to_tensor(loss_scale, tf.float32) + else: + raise ValueError("`opt` should have a `_learning_rate` or `_lr` named field") diff --git a/examples/DCNv2/main_mxrec.py b/examples/DCNv2/main_mxrec.py index 540445e8b6fcef77a6790b5de7e927841950a182..a04e1c475533bd9c3a02a55e12ed81feef1d67a0 100644 --- a/examples/DCNv2/main_mxrec.py +++ b/examples/DCNv2/main_mxrec.py @@ -14,17 +14,20 @@ # limitations under the License. 
# ============================================================================== +import os +import random +import shutil import time import warnings -import random from glob import glob - from sklearn.metrics import roc_auc_score + import numpy as np + from npu_bridge.npu_init import * from model import MyModel -from dlrm.model.config import sess_config, Config +from config import sess_config, Config, SSD_DATA_PATH, CacheModeEnum from optimizer import get_dense_and_sparse_optimizer from mx_rec.core.asc.helper import FeatureSpec, get_asc_insert_func from mx_rec.core.asc.manager import start_asc_pipeline @@ -53,8 +56,8 @@ def add_timestamp_func(batch): return batch -def make_batch_and_iterator(cfg, feature_spec_list, is_training, dump_graph, use_faae=False): - if cfg.USE_PIPELINE_TEST: +def make_batch_and_iterator(config, feature_spec_list, is_training, dump_graph, is_use_faae=False): + if config.USE_PIPELINE_TEST: num_parallel = 1 else: num_parallel = 8 @@ -62,9 +65,9 @@ def make_batch_and_iterator(cfg, feature_spec_list, is_training, dump_graph, use def extract_fn(data_record): features = { # Extract features using the keys set during creation - 'label': tf.compat.v1.FixedLenFeature(shape=(cfg.line_per_sample,), dtype=tf.int64), - 'sparse_feature': tf.compat.v1.FixedLenFeature(shape=(26 * cfg.line_per_sample,), dtype=tf.int64), - 'dense_feature': tf.compat.v1.FixedLenFeature(shape=(13 * cfg.line_per_sample,), dtype=tf.float32), + 'label': tf.compat.v1.FixedLenFeature(shape=(config.line_per_sample,), dtype=tf.int64), + 'sparse_feature': tf.compat.v1.FixedLenFeature(shape=(26 * config.line_per_sample,), dtype=tf.int64), + 'dense_feature': tf.compat.v1.FixedLenFeature(shape=(13 * config.line_per_sample,), dtype=tf.float32), } sample = tf.compat.v1.parse_single_example(data_record, features) return sample @@ -77,24 +80,24 @@ def make_batch_and_iterator(cfg, feature_spec_list, is_training, dump_graph, use return batch if is_training: - files_list = glob(os.path.join(cfg.data_path, cfg.train_file_pattern) + '/*.tfrecord') + files_list = glob(os.path.join(config.data_path, config.train_file_pattern) + '/*.tfrecord') else: - files_list = glob(os.path.join(cfg.data_path, cfg.test_file_pattern) + '/*.tfrecord') + files_list = glob(os.path.join(config.data_path, config.test_file_pattern) + '/*.tfrecord') dataset = tf.data.TFRecordDataset(files_list, num_parallel_reads=num_parallel) - batch_size = cfg.batch_size // cfg.line_per_sample + batch_size = config.batch_size // config.line_per_sample - dataset = dataset.shard(cfg.rank_size, cfg.rank_id) + dataset = dataset.shard(config.rank_size, config.rank_id) if is_training: dataset = dataset.shuffle(batch_size * 1000, seed=SHUFFLE_SEED) if is_training: - dataset = dataset.repeat(cfg.train_epoch) + dataset = dataset.repeat(config.train_epoch) else: - dataset = dataset.repeat(cfg.test_epoch) + dataset = dataset.repeat(config.test_epoch) dataset = dataset.map(extract_fn, num_parallel_calls=num_parallel).batch(batch_size, drop_remainder=True) dataset = dataset.map(reshape_fn, num_parallel_calls=num_parallel) - if use_faae: + if is_use_faae: dataset = dataset.map(add_timestamp_func) if not MODIFY_GRAPH_FLAG: @@ -125,7 +128,7 @@ def model_forward(feature_list, hash_table_list, batch, is_train, modify_graph): elif len(embedding_list) > 1: emb = tf.reduce_sum(embedding_list, axis=0, keepdims=False) else: - raise ValueError("The length of embedding_list must be greater than or equal to 1.") + raise ValueError("the length of embedding_list must be greater than or 
equal to 1.") my_model = MyModel() model_output = my_model.build_model(embedding=emb, dense_feature=batch["dense_feature"], @@ -155,7 +158,7 @@ def evaluate(): try: eval_current_steps += 1 eval_start = time.time() - eval_loss, pred, label = sess.run([eval_model["loss"], eval_model["pred"], eval_label]) + eval_loss, pred, label = sess.run([eval_model.get("loss"), eval_model.get("pred"), eval_label]) eval_cost = time.time() - eval_start eval_qps = (1 / eval_cost) * rank_size * cfg.batch_size log_loss_list += list(eval_loss.reshape(-1)) @@ -186,7 +189,7 @@ def evaluate_fix(step): while not finished: try: eval_current_steps += 1 - eval_loss, pred, label = sess.run([eval_model["loss"], eval_model["pred"], eval_model["label"]]) + eval_loss, pred, label = sess.run([eval_model.get("loss"), eval_model.get("pred"), eval_model.get("label")]) log_loss_list += list(eval_loss.reshape(-1)) pred_list += list(pred.reshape(-1)) label_list += list(label.reshape(-1)) @@ -245,12 +248,33 @@ def create_feature_spec_list(use_timestamp=False): return feature_spec_list +def _del_related_dir(del_path: str) -> None: + if not os.path.isabs(del_path): + del_path = os.path.join(os.getcwd(), del_path) + dirs = glob(del_path) + for sub_dir in dirs: + shutil.rmtree(sub_dir, ignore_errors=True) + logger.info(f"delete dir:{sub_dir}") + + +def _clear_saved_model() -> None: + _del_related_dir("/root/ascend/log/*") + if os.getenv("CACHE_MODE", "") != CacheModeEnum.SSD.value: + return + logger.info("Current cache mode is SSD, and file overwrite is not allowed in SSD mode, deleting exist directory" + " then create empty directory for this use case.") + for sub_path in SSD_DATA_PATH: + _del_related_dir(sub_path) + os.makedirs(sub_path, mode=0o550, exist_ok=True) + logger.info(f"Create dir:{sub_path}") + + if __name__ == "__main__": tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) warnings.filterwarnings("ignore") + _clear_saved_model() - rank_id = int(os.getenv("RANK_ID")) if os.getenv("RANK_ID") else None - rank_size = int(os.getenv("RANK_SIZE")) if os.getenv("RANK_SIZE") else None + rank_size = int(os.getenv("TRAIN_RANK_SIZE")) if os.getenv("TRAIN_RANK_SIZE") else None interval = int(os.getenv("INTERVAL")) if os.getenv("INTERVAL") else None train_steps = 10000 eval_steps = 1360 @@ -261,8 +285,8 @@ if __name__ == "__main__": MODIFY_GRAPH_FLAG = bool(int(os.getenv("USE_MODIFY_GRAPH", 0))) use_faae = bool(int(os.getenv("USE_FAAE", 0))) except ValueError as err: - raise ValueError(f"please correctly config USE_DYNAMIC_EXPANSION or USE_MULTI_LOOKUP or USE_FAAE " - f"or USE_MODIFY_GRAPH only 0 or 1 is supported.") from err + raise ValueError("please correctly config USE_DYNAMIC_EXPANSION or USE_MULTI_LOOKUP or USE_FAAE " + "or USE_MODIFY_GRAPH only 0 or 1 is supported.") from err use_dynamic = bool(int(os.getenv("USE_DYNAMIC", 0))) logger.info(f"USE_DYNAMIC: {use_dynamic}") @@ -270,7 +294,7 @@ if __name__ == "__main__": use_dynamic=use_dynamic, use_dynamic_expansion=use_dynamic_expansion) IF_LOAD = False rank_id = mxrec_util.communication.hccl_ops.get_rank_id() - filelist = glob(f"./saved-model/sparse-model-0") + filelist = glob("./saved-model/sparse-model-0") if filelist: IF_LOAD = True ConfigInitializer.get_instance().if_load = IF_LOAD @@ -286,16 +310,15 @@ if __name__ == "__main__": feature_spec_list_eval = create_feature_spec_list(use_timestamp=False) train_batch, train_iterator = make_batch_and_iterator(cfg, feature_spec_list_train, is_training=True, - dump_graph=True, use_faae=use_faae) + dump_graph=True, 
is_use_faae=use_faae) eval_batch, eval_iterator = make_batch_and_iterator(cfg, feature_spec_list_eval, is_training=False, - dump_graph=False, use_faae=use_faae) + dump_graph=False, is_use_faae=use_faae) logger.info(f"train_batch: {train_batch}") if use_faae: cfg.dev_vocab_size = cfg.dev_vocab_size // 2 optimizer_list = [get_dense_and_sparse_optimizer(cfg)] - sparse_optimizer_list = [sparse_optimizer for dense_optimizer, sparse_optimizer in optimizer_list] # note: variance_scaling_initializer only support HBM mode emb_initializer = tf.compat.v1.truncated_normal_initializer(stddev=0.05, seed=SPARSE_HASHTABLE_SEED) \ @@ -306,7 +329,6 @@ if __name__ == "__main__": dim=tf.TensorShape([cfg.emb_dim]), name="sparse_embeddings", emb_initializer=emb_initializer, - optimizer_list=[sparse_optimizer_list[0]._optimizer], **cfg.get_emb_table_cfg() ) if use_faae: @@ -323,7 +345,7 @@ if __name__ == "__main__": rank_size = mxrec_util.communication.hccl_ops.get_rank_size() train_ops = [] # multi task training - for loss, (dense_optimizer, sparse_optimizer) in zip([train_model["loss"]], optimizer_list): + for loss, (dense_optimizer, sparse_optimizer) in zip([train_model.get("loss")], optimizer_list): # do dense optimization grads = dense_optimizer.compute_gradients(loss, var_list=dense_variables) avg_grads = [] @@ -336,9 +358,9 @@ if __name__ == "__main__": train_ops.append(dense_optimizer.apply_gradients(avg_grads)) if use_dynamic_expansion: - from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS + from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET - train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS) + train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET) train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB) # do sparse optimization by addr sparse_grads = sparse_optimizer.compute_gradients(loss, train_emb_list) # local_embedding @@ -405,11 +427,11 @@ if __name__ == "__main__": start_time = time.time() try: - grad, loss = sess.run([train_ops, train_model["loss"]]) + grad, loss = sess.run([train_ops, train_model.get("loss")]) lr = sess.run(cfg.learning_rate) global_step = sess.run(cfg.global_step) except tf.errors.OutOfRangeError: - logger.info(f"Encounter the end of Sequence for training.") + logger.info("Encounter the end of Sequence for training.") break end_time = time.time() diff --git a/examples/DCNv2/run.sh b/examples/DCNv2/run.sh index f30e0ac60114a5c642fb48feae07e06cd61653e3..860ff53f96a67b1ad5b168b1740f9c2adc987e6c 100644 --- a/examples/DCNv2/run.sh +++ b/examples/DCNv2/run.sh @@ -75,8 +75,6 @@ RANK_ID_START=0 export MXREC_MODE="ASC" echo "MXREC_MODE is $MXREC_MODE" -export USE_MPI=1 -echo "USE_MPI is $USE_MPI" export py=main_mxrec.py echo "py is $py" @@ -94,7 +92,6 @@ if [ -n "$ip" ]; then echo "CM_CHIEF_DEVICE=$CM_CHIEF_DEVICE" echo "CM_WORKER_IP=$CM_WORKER_IP" echo "CM_WORKER_SIZE=$CM_WORKER_SIZE" - echo "ASCEND_VISIBLE_DEVICES=$ASCEND_VISIBLE_DEVICES" else # ranktable echo "Current is ranktable solution, hccl json file:${hccl_cfg_json}" @@ -103,30 +100,9 @@ else export RANK_TABLE_FILE=${hccl_cfg_json} fi -if [ $USE_MPI -eq 0 ]; then - echo "use for loop to start tasks" - for((RANK_ID=$RANK_ID_START;RANK_ID<$((RANK_SIZE+RANK_ID_START));RANK_ID++)); - do - #设置环境变量,不需要修改 - echo "Device ID: $RANK_ID" - export RANK_ID=$RANK_ID - export ASCEND_DEVICE_ID=$RANK_ID - ASCEND_DEVICE_ID=$RANK_ID - if [ -d 
$cur_path/output/${ASCEND_DEVICE_ID} ];then
-        rm -rf $cur_path/output/${ASCEND_DEVICE_ID}
-        mkdir -p $cur_path/output/${ASCEND_DEVICE_ID}
-    else
-        mkdir -p $cur_path/output/${ASCEND_DEVICE_ID}
-    fi
-    nohup python3 ${py} > $cur_path/output/$ASCEND_DEVICE_ID/test_$ASCEND_DEVICE_ID.log 2>&1 &
-    done
-else
-    echo "use horovod to start tasks"
-    # GLOG_stderrthreshold -2:TRACE -1:DEBUG 0:INFO 1:WARN 2.ERROR, 默认为INFO
-    mpi_args='-x BIND_INFO="0:12 12:48 60:48" -x GLOG_stderrthreshold=2 -x GLOG_logtostderr=true -bind-to none -x NCCL_SOCKET_IFNAME=docker0 -mca btl_tcp_if_exclude docker0'
-
-    horovodrun --network-interface ${interface} -np ${num_process} --mpi-args "${mpi_args}" --mpi -H localhost:${local_rank_size} \
-    python3.7 ${py} 2>&1 | tee temp_${CACHE_MODE}_${num_process}p_$(date +%Y%m%d_%H%M%S).log
-fi
-
+echo "use horovod to start tasks"
+# GLOG_stderrthreshold -2:TRACE -1:DEBUG 0:INFO 1:WARN 2:ERROR,默认为INFO
+mpi_args='-x BIND_INFO="0:12 12:48 60:48" -x GLOG_stderrthreshold=2 -x GLOG_logtostderr=true -bind-to none -x NCCL_SOCKET_IFNAME=docker0 -mca btl_tcp_if_exclude docker0'
+horovodrun --network-interface ${interface} -np ${num_process} --mpi-args "${mpi_args}" --mpi -H localhost:${local_rank_size} \
+python3.7 ${py} 2>&1 | tee temp_${CACHE_MODE}_${num_process}p_$(date +%Y%m%d_%H%M%S).log
diff --git a/examples/WideDeep/README.md b/examples/WideDeep/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f4815cd96ce0613e5b0b575dcd9915c511a7ea27
--- /dev/null
+++ b/examples/WideDeep/README.md
@@ -0,0 +1,586 @@
+# wide&deep模型 迁移样例(基于DLRM模型框架)
+
+开源项目在保证原有结构不变的情况下,可采用替换相关API接口的方式将项目由GPU逐步迁移至NPU,再迁移至mxRec。在模型迁移适配过程中可能因兼容性问题导致迁移失败,此处提供另一种模型适配方案。
+
+***
+## 开源项目链接
+Commits on Apr 29, 2022,提交的SHA-1 hash值(提交ID):4bbfb492b872c5a3290a2bce1ed5c160162558a3
+commit链接:https://github.com/ZiyaoGeng/RecLearn/tree/4bbfb492b872c5a3290a2bce1ed5c160162558a3
+```shell
+https://github.com/ZiyaoGeng/RecLearn
+```
+***
+## 数据集
+
+```shell
+Criteo4500w数据集:
+https://ailab.criteo.com/ressources/kaggle-display-advertising-challenge-dataset.tar.gz
+```
+***
+## 数据集预处理
+
+### 解压文件列表
+- train.txt
+- test.txt
+- readme.txt
+
+test.txt因缺少label列无法使用,故将train.txt数据集切分为10份,train_01.txt~train_09.txt为训练集,train_10.txt为测试集。数据预处理文件:criteo.py。
+
+***
+### 数据预处理运行脚本
+```shell
+python criteo.py --data_path data_path --output_path output_path
+```
+参数说明:
+- data_path: train.txt的路径,如:"D:\dat\train.txt"
+- output_path: tfrecord存放路径,如:"D:\dat\tfrecord\ "
+***
+
+### criteo.py
+#### 1. 分割数据集
+调用`criteo.py`文件中的`get_split_file_path(parent_path, dataset_path, sample_num=4600000)`方法将数据集分割,`sample_num=4600000`是每个子数据集的样本数量,返回包含全部子数据集路径的列表。
+
+```python
+# get txt_list
+file_split_list = get_split_file_path(dataset_path=data_path)
+```
+***
+#### 2. 建立特征映射
+调用`criteo.py`文件中的`get_fea_map()`方法,以`{'C1':{}, 'C2':{},..., 'I1':{},...}`形式储存dense_feature的最大最小值以及sparse_feature去重后的特征映射。
+
+```python
+# get feature_map
+feature_map = get_fea_map(split_file_list=file_split_list)
+```
+***
+#### 3. dense_feature分桶离散化
+调用`criteo.py`文件中的`rec_kbins_discretizer(data_df, n_bins, min_max_dict)`方法将dense_feature分桶离散化,`n_bins=1000`。
+
+```python
+# dense feature: Bin continuous data into intervals.
+data_df[dense_features] = rec_kbins_discretizer(data_df[dense_features], 1000, feature_map)
+```
+***
+#### 4. sparse_feature特征映射
+通过如下操作将原始的字符串数据映射为0~max的int64数据。
+
+```python
+# sparse feature: mapping
+for col in sparse_features:
+    try:
+        data_df[col] = data_df[col].map(lambda x: feature_map[col][x])
+    except KeyError as e:
+        raise KeyError("Feature {} not found in dataset".format(col)) from e
+```
+***
+#### 5. 39个特征增加偏移项
+开源项目deep部分对39个特征分别作了embedding,即建了39个表。本项目只建了一张表,因此需要对每个特征对应的值作偏移。`slot_size_array`中的值分别对应各特征去重后的类别数。
+
+```python
+# add offsets
+slot_size_array = [
+    1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001,
+    1462, 585, 10131228, 2202609, 307, 25, 12519, 635, 5, 93147, 5685, 8351594, 3196,
+    29, 14994, 5461307, 12, 5654, 2174, 5, 7046548, 19, 17, 286182, 106, 142573
+]
+offset_size_list = np.cumsum([0] + slot_size_array[:-1])
+for col_index in range(1, len(offset_size_list) + 1):
+    data_df.iloc[:, col_index] += offset_size_list[col_index - 1]
+```
+***
+#### 6. 数据集格式转换:txt >> tfrecord
+调用`criteo.py`文件中的`convert_input2tfrd(in_file_path, out_file_path)`方法将txt文件转换为tfrecord文件。
+
+```python
+# txt to tfrecords
+convert_input2tfrd(in_file_path=file, out_file_path=output_path)
+```
+***
+
+## 模型运行
+
+参考mxRec的`README.md`在NPU服务器上配置环境、安装镜像并创建容器后,可参考DLRM模型运行命令启动模型训练。模型运行脚本是run.sh,运行此脚本需要四个参数:so_path、mx_rec_package_path、hccl_cfg_json以及dlrm_criteo_data_path。其中,
+- so_path: mxRec中libasc所在路径,镜像中已安装mxRec,因此so_path是:/usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec/libasc/
+- mx_rec_package_path: mxRec包的安装路径,镜像中是:/usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec/
+- hccl_cfg_json: hccl配置文件所在路径,一般是当前路径下的hccl文件
+- dlrm_criteo_data_path: Wide&Deep模型所需数据集的路径,根据实际情况进行配置
+
+运行mxRec有两种方式,一种是使用hccl配置文件(rank table方案),一种是不使用hccl配置文件(去rank table方案)。
+- 使用hccl配置文件(rank table方案)
+```shell
+bash run.sh {so_path} {mx_rec_package_path} {hccl_cfg_json} {dlrm_criteo_data_path}
+```
+***
+- 不使用hccl配置文件(去rank table方案),需在末尾额外传入本机IP
+```shell
+bash run.sh {so_path} {mx_rec_package_path} {hccl_cfg_json} {dlrm_criteo_data_path} {IP}
+```
+如:bash run.sh /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec/libasc/ /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec/ hccl_json_8p.json /dataset 10.10.10.10。
+**注意:** 采用去rank table方案时,即使当前路径下不存在hccl配置文件,模型仍可正常运行。
+
+
+## 模型结果
+[开源项目](https://github.com/ZiyaoGeng/RecLearn)使用Criteo4500W数据集在GPU上训练模型,结果为`Log Loss=0.4692`、`AUC=0.7930`。适配完成后,固定`CACHE_MODE="HBM"`、`USE_FAAE=0`,在`run.sh`中配置其余开关,运行结果如下。下表中Use_*四列为原表Options分组下的开关,Log Loss与AUC为Criteo4500W数据集上的结果。
+
+| Model | Use_Dynamic | Use_Dynamic_Expansion | Use_Multi_Lookup | Use_Modify_Graph | Log Loss | AUC |
+| --- | --- | --- | --- | --- | --- | --- |
+| WDL | 0 | 0 | 0 | 0 | 0.4592 | 0.7934 |
+| WDL | 0 | 1 | 0 | 0 | 0.4593 | 0.7933 |
+| WDL | 1 | 0 | 0 | 0 | 0.4594 | 0.7932 |
+| WDL | 1 | 1 | 0 | 0 | 0.4594 | 0.7932 |
+| WDL | 1 | 1 | 1 | 0 | 0.4590 | 0.7937 |
+| WDL | 0 | 0 | 0 | 1 | 0.4593 | 0.7934 |
+| WDL | 0 | 1 | 0 | 1 | 0.4593 | 0.7933 |
+| WDL | 1 | 0 | 0 | 1 | 0.4593 | 0.7933 |
+| WDL | 1 | 1 | 0 | 1 | 0.4594 | 0.7932 |
+| WDL | 1 | 1 | 1 | 1 | 0.4589 | 0.7937 |
+ + +*** +## 模型迁移 + +**迁移思路:** 在现有已适配好的dlrm模型框架下,改动相关代码逻辑,完成Wide&deep模型的适配。**核心:根据开源项目model代码修改`model.py`;数据处理操作一部分放入`criteo.py`,一部分放入`main_mxrec.py`中`make_batch_and_iterator()`内;`main_mxrec.py`中其他相关代码改动主要是为了适配mxrec提供的相关特性。** +详细改动见https://gitee.com/ascend/mxrec/pulls/171/commits,Commits ID:7a05b033d41af51df9aed7414ad04216dff821cc。 +下文所提到的`动态扩容`、`动态shape`、`自动改图`、`一表多查`是mxrec提供的相关特性,开关选项见`run.sh`。 + +```shell +# run.sh: 32~37行 +export USE_DYNAMIC=0 # 0:静态shape;1:动态shape +export CACHE_MODE="HBM" # HBM;DDR;SSD +export USE_FAAE=0 # 0:关闭准入淘汰;1:开启准入淘汰 +export USE_DYNAMIC_EXPANSION=0 # 0:关闭动态扩容;1: 开启动态扩容 +export USE_MULTI_LOOKUP=0 # 0:一表一查;1:一表多查 +export USE_MODIFY_GRAPH=0 # 0:feature spec模式;1:自动改图模式 +``` + +*** +### DLRM模型框架 +**迁移说明:** 迁移过程中未使用`gradient_descent_w.py`、`mean_auc.py`。 + +- config.py +- delay_loss_scale.py +- gradient_descent_w.py +- main_mxrec.py +- mean_auc.py +- model.py +- optimizer.py +- run.sh + +*** + +### 代码改动说明 +#### 1. config.py +实验超参数配置如下:取消动态学习率逻辑,学习率固定为0.001。 + +```python +# 88~89行 +lr_sparse = self.base_lr_sparse * lr_factor_constant +lr_dense = self.base_lr_dense * lr_factor_constant +# 140~146行 +_lr_scheduler = LearningRateScheduler( + 0.001, + 0.001, + LR_SCHEDULE_STEPS[0], + LR_SCHEDULE_STEPS[1], + LR_SCHEDULE_STEPS[2], +) +# 超参数 +self.batch_size = 4096 +self.line_per_sample = 1 +self.train_epoch = 1 +self.test_epoch = 9 +self.emb_dim = 8 +``` +*** + + +#### 2. model.py +迁移过程中,`model.py`需参考开源项目文件`reclearn/models/ranking/wdl.py`的代码逻辑,使用tensorflow的低阶API重新编写。输出参数必须包括`loss`,`prediction`,`label`,`trainable_variables`。**迁移重点:mxRec对推荐模型中sparse_feature的创表查表操作作了加速,使用`create_table`与`sparse_lookup`接口替换tensorflow中的`tf.nn.embedding_lookup`接口。** 因此在适配开源项目时,会将sparse_feature的embedding操作放在模型结构外。 + +**reclearn开源项目原始代码:** +```python +# wdl.py +import tensorflow as tf +from tensorflow.keras import Model +from tensorflow.keras.layers import Dense, Embedding, Dropout, Input +from tensorflow.keras.regularizers import l2 + +from reclearn.layers import Linear, MLP +from reclearn.layers.utils import index_mapping + +class WideDeep(Model): + def __init__(self, feature_columns, hidden_units, activation='relu', + dnn_dropout=0., embed_reg=0., w_reg=0.): + """Wide&Deep. + Args: + :param feature_columns: A list. [{'feat_name':, 'feat_num':, 'embed_dim':}, ...] + :param hidden_units: A list. Neural network hidden units. + :param activation: A string. Activation function of MLP. + :param dnn_dropout: A scalar. Dropout of MLP. + :param embed_reg: A scalar. The regularization coefficient of embedding. + :param w_reg: A scalar. The regularization coefficient of Linear. 
+ :return + """ + super(WideDeep, self).__init__() + self.feature_columns = feature_columns + self.embed_layers = { + feat['feat_name']: Embedding(input_dim=feat['feat_num'], + input_length=1, + output_dim=feat['embed_dim'], + embeddings_initializer='random_normal', + embeddings_regularizer=l2(embed_reg)) + for feat in self.feature_columns + } + self.map_dict = {} + self.feature_length = 0 + for feat in self.feature_columns: + self.map_dict[feat['feat_name']] = self.feature_length + self.feature_length += feat['feat_num'] + self.dnn_network = MLP(hidden_units, activation, dnn_dropout) + self.linear = Linear(self.feature_length, w_reg=w_reg) + self.final_dense = Dense(1, activation=None) + + def call(self, inputs): + sparse_embed = tf.concat([self.embed_layers[feat_name](value) for feat_name, value in inputs.items()], axis=-1) + x = sparse_embed # (batch_size, field * embed_dim) + # Wide + wide_inputs = index_mapping(inputs, self.map_dict) + wide_inputs = tf.concat([value for _, value in wide_inputs.items()], axis=-1) + wide_out = self.linear(wide_inputs) + # Deep + deep_out = self.dnn_network(x) + deep_out = self.final_dense(deep_out) + # out + outputs = tf.nn.sigmoid(0.5 * wide_out + 0.5 * deep_out) + return outputs + + def summary(self): + inputs = { + feat['feat_name']: Input(shape=(), dtype=tf.int32, name=feat['feat_name']) + for feat in self.feature_columns + } + Model(inputs=inputs, outputs=self.call(inputs)).summary() + +``` +`self.embed_layers`是对数据集中39个特征分别建表作embedding的操作,迁移后对应的代码逻辑见`main_mxrec.py`。 +`self.map_dict`统计了各特征需增加的偏移量。 +`index_mapping`是对数据增加偏移量的操作,迁移后对应的代码逻辑见`criteo.py`。 + +**迁移后代码:** +```python +# model.py +import time +from easydict import EasyDict as edict + +import tensorflow as tf + + +model_cfg = edict() +model_cfg.loss_mode = "batch" +LOSS_OP_NAME = "loss" +LABEL_OP_NAME = "label" +VAR_LIST = "variable" +PRED_OP_NAME = "pred" + + +class MyModel: + def __init__(self): + self.kernel_init = None + self._loss_fn = None + self.is_training = None + + def build_model(self, + wide_embedding=None, + deep_embedding=None, + label=None, + is_training=True, + seed=None, + dropout_rate=None, + batch_norm=False): + + with tf.variable_scope("wide_deep", reuse=tf.AUTO_REUSE): + self._loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True) + self.is_training = is_training + + # wide + batch_size, wide_num, wide_emb_dim = wide_embedding.shape + wide_input = tf.reshape(wide_embedding[:,0], shape=(batch_size, wide_num * 1)) + wide_output = tf.reshape(tf.reduce_sum(wide_input, axis=1), shape=(-1,1)) + + # deep + batch_size, deep_num, deep_emb_dim = deep_embedding.shape + deep_input = tf.reshape(deep_embedding, shape=(batch_size, deep_num * deep_emb_dim)) + + ## MLP + hidden_units = [256,128,64] + net = deep_input + for i,unit in enumerate(hidden_units): + + net = tf.layers.dense(net, units=unit, activation='relu', name=f'hidden_layer_{i}', + kernel_initializer=tf.glorot_uniform_initializer(seed=seed), + bias_initializer=tf.zeros_initializer()) + + if dropout_rate is not None and 0.0 < dropout_rate < 1.0: + net = tf.layers.dropout(net,dropout_rate,training=self.is_training) + if batch_norm: + net = tf.layers.batch_normalization(net, training=self.is_training) + + deep_output = tf.layers.dense(net, units=1, activation=None, name='deep_output', + kernel_initializer=tf.glorot_uniform_initializer(seed=seed), + bias_initializer=tf.zeros_initializer()) + + total_logits = 0.5 * tf.add(wide_output,deep_output,name='total_logits') + loss = self._loss_fn(label, total_logits) + prediction = 
tf.sigmoid(total_logits) + trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='wide_deep') + return {LOSS_OP_NAME: loss, + PRED_OP_NAME: prediction, + LABEL_OP_NAME: label, + VAR_LIST: trainable_variables} + + +my_model = MyModel() + +``` +*** +#### 3. main_mxrec.py + +`main_mxrec.py`文件中的函数如下所示。`make_batch_and_iterator()`是读取数据集以及对数据作处理的函数;`model_forward()`是前向过程函数;`evaluate()`与`evaluate_fix()`是评估函数,用于计算测试集的AUC与loss。`add_timestamp_func()`与特征准入、淘汰有关;`create_feature_spec_list()`是生成元素为FeatureSpec类的列表的函数,其返回值是`make_batch_and_iterator()`所需的传参。特征准入与淘汰、FeatureSpec类、自动改图等解释见[mxRec用户指南](https://www.hiascend.com/document/detail/zh/mind-sdk/60rc1/mxRec/mxrecug/mxrecug_0001.html)。 + +- `add_timestamp_func()` +- `make_batch_and_iterator()` +- `model_forward()` +- `evaluate()` +- `evaluate_fix()` +- `create_feature_spec_list()` + +**迁移代码改动说明:** `add_timestamp_func()`、`evaluate()`、`evaluate_fix()`未作修改。 +
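+
+在逐节阅读3.1~3.4之前,下面给出一段示意性的胶水代码,用于说明上述函数之间的调用关系。这段代码仅为理解整体流程而作的假设,并非`main_mxrec.py`原文:其中`ModelArgs`及具体传参方式均为假设,`sparse_hashtable_wide`、`sparse_hashtable_deep`等名称来自下文3.3节,实际调用以`main_mxrec.py`为准。
+
+```python
+import collections
+
+# hypothetical container mirroring the attributes model_forward() reads in 3.2
+ModelArgs = collections.namedtuple(
+    "ModelArgs",
+    ["feature_list", "wide_hash_table_list", "deep_hash_table_list",
+     "batch", "is_train", "modify_graph", "is_use_faae"])
+
+# 3.3: one FeatureSpec per created table (entries are duplicated when multi-lookup is on)
+feature_specs = create_feature_spec_list(use_timestamp=use_faae)
+# 3.1: build the input pipeline; the specs drive the host-side key-insert stage
+train_batch, train_iterator = make_batch_and_iterator(
+    cfg, feature_specs, is_training=True, dump_graph=True, is_use_faae=use_faae)
+# 3.2: embedding lookup plus the wide/deep forward pass
+train_model = model_forward(ModelArgs(
+    feature_list=feature_specs,
+    wide_hash_table_list=[sparse_hashtable_wide],
+    deep_hash_table_list=[sparse_hashtable_deep],
+    batch=train_batch,
+    is_train=True,
+    modify_graph=MODIFY_GRAPH_FLAG,
+    is_use_faae=use_faae))
+```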
+ +3.1 读取数据集:`make_batch_and_iterator()` + +```python +# main_mxrec.py:100~104行 +def map_fn(batch): + new_batch = batch + new_batch['sparse_feature'] = tf.concat([batch['dense_feature'], batch['sparse_feature']], axis=1) + return new_batch +dataset = dataset.map(map_fn, num_parallel_calls=num_parallel) +``` +`map_fn()`:该函数是将分桶后的dense_feature与sparse_feature合并为新sparse_feature。该操作主要与`FeatureSpec()`、`sparse_lookup()`传入参数有关。 + +```python +# main_mxrec.py:109~118行 +if not MODIFY_GRAPH_FLAG: + + # Enable EOSDataset manually. + librec = import_host_pipeline_ops(LIBREC_EOS_OPS_SO) + channel_id = 0 if is_training else 1 + # 此处eos_map的调用必须先于insert_func,避免多卡数据不均匀的情况 + dataset = dataset.eos_map(librec, channel_id, kwargs.get("max_train_steps", max_train_steps), + kwargs.get("max_eval_steps", eval_steps)) + insert_fn = get_asc_insert_func(tgt_key_specs=feature_spec_list, is_training=is_training, dump_graph=dump_graph) + dataset = dataset.map(insert_fn) +``` +`dataset.eos_map()`:该函数主要是为了解决FeatureSpec模式下开`动态shape`选项卡,训练结束无法正常退出的问题。 + +*** +3.2 模型前向传播过程 + +```python +# main_mxrec.py:127~179行 +def model_forward(feature_list, wide_hash_table_list, deep_hash_table_list, batch, is_train, modify_graph, is_use_faae=False): + wide_embedding_list = [] + deep_embedding_list = [] + wide_feature_list = [] + deep_feature_list = [] + if is_use_faae: + feature_list_copy = feature_list[:-1] + else: + feature_list_copy = feature_list + + for i,item in enumerate(feature_list_copy): + if i % 2 == 0: + wide_feature_list.append(item) + else: + deep_feature_list.append(item) + + logger.debug(f"In model_forward function, is_train: {is_train}, feature_list: {len(feature_list)}, " + f"wide_hash_table_list: {len(wide_hash_table_list)}, deep_hash_table_list: {len(deep_hash_table_list)}") + + # wide + for wide_feature, wide_hash_table in zip(wide_feature_list, wide_hash_table_list): + if MODIFY_GRAPH_FLAG: + wide_feature = batch["sparse_feature"] + wide_embedding = sparse_lookup(wide_hash_table, wide_feature, cfg.send_count, dim=None, is_train=is_train, + name="wide_embedding_lookup", modify_graph=modify_graph, batch=batch, + access_and_evict_config=None) + wide_embedding_list.append(wide_embedding) + + # deep + for deep_feature, deep_hash_table in zip(deep_feature_list, deep_hash_table_list): + if MODIFY_GRAPH_FLAG: + deep_feature = batch["sparse_feature"] + deep_embedding = sparse_lookup(deep_hash_table, deep_feature, cfg.send_count, dim=None, is_train=is_train, + name="deep_embedding_lookup", modify_graph=modify_graph, batch=batch, + access_and_evict_config=None) + deep_embedding_list.append(deep_embedding) + + if len(wide_embedding_list) == 1: + wide_emb = wide_embedding_list[0] + deep_emb = deep_embedding_list[0] + elif len(wide_embedding_list) > 1: + wide_emb = tf.reduce_sum(wide_embedding_list, axis=0, keepdims=False) + deep_emb = tf.reduce_sum(deep_embedding_list, axis=0, keepdims=False) + else: + raise ValueError("the length of embedding_list must be greater than or equal to 1.") + my_model = MyModel() + model_output = my_model.build_model(wide_embedding=wide_emb, + deep_embedding=deep_emb, + label=batch["label"], + is_training=is_train, + seed=dense_hashtable_seed, + dropout_rate=0.5) + return model_output +``` +该函数是前向传播函数,主要包括sparse_feature的embedding操作(查表)与model前向操作。130-141行代码是预处理`sparse_lookup`传参的逻辑。147-162行代码对应开源项目中wide部分`self.linear`与deep部分`self.embed_layers`对39个特征作embedding的逻辑。164-171行是配置mxrec中`一表多查`特性的逻辑。 + +*** +3.3 创表操作 + +```python +# main_mxrec.py: 273~296行 +def create_feature_spec_list(use_timestamp=False): + 
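+    # one FeatureSpec per embedding table; extra specs are appended below when
+    # multi-lookup or timestamp-based eviction is enabled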
access_threshold = None + eviction_threshold = None + if use_timestamp: + access_threshold = 1000 + eviction_threshold = 180 + + feature_spec_list = [FeatureSpec("sparse_feature", table_name="wide_embeddings", batch_size=cfg.batch_size, + access_threshold=access_threshold, eviction_threshold=eviction_threshold), + FeatureSpec("sparse_feature", table_name="deep_embeddings", batch_size=cfg.batch_size, + access_threshold=access_threshold, eviction_threshold=eviction_threshold)] + + if use_multi_lookup: + feature_spec_list.extend([FeatureSpec("sparse_feature", table_name="wide_embeddings", + batch_size=cfg.batch_size, + access_threshold=access_threshold, + eviction_threshold=eviction_threshold), + FeatureSpec("sparse_feature", table_name="deep_embeddings", + batch_size=cfg.batch_size, + access_threshold=access_threshold, + eviction_threshold=eviction_threshold)]) + if use_timestamp: + feature_spec_list.append(FeatureSpec("timestamp", is_timestamp=True)) + return feature_spec_list + +``` + +```python +# main_mxrec.py: 379~397行 +# 创表操作 +wide_emb_initializer = tf.compat.v1.truncated_normal_initializer(stddev=0.05, seed=sparse_hashtable_seed) +deep_emb_initializer = tf.compat.v1.truncated_normal_initializer(stddev=0.05, seed=sparse_hashtable_seed) + +sparse_hashtable_wide = create_table( + key_dtype=cfg.key_type, + dim=tf.TensorShape([cfg.emb_dim]), + name="wide_embeddings", + emb_initializer=wide_emb_initializer, + **cfg.get_emb_table_cfg() +) + +sparse_hashtable_deep = create_table( + key_dtype=cfg.key_type, + dim=tf.TensorShape([cfg.emb_dim]), + name="deep_embeddings", + emb_initializer=deep_emb_initializer, + **cfg.get_emb_table_cfg() +) +``` +`create_feature_spec_list()`的返回值是`make_batch_and_iterator()`、`model_forward()`的传参;`create_table()`的返回值是`sparse_lookup()`的传参。 +**注意:`len(feature_spec_list)`应与使用`create_table()`接口创建的表数相等;开启`一表多查`选项卡,feature_spec_list中的元素重复添加一次;开启`特征淘汰`选项卡,feature_spec_list增加时间戳的FeatureSpec类元素**。 + +*** + +3.4 模型反向传播过程 +```python +# main_mxrec.py: 410~442行 +train_variables, emb_variables = get_dense_and_sparse_variable() + +rank_size = mxrec_util.communication.hccl_ops.get_rank_size() +train_ops = [] +# multi task training +for loss, (model_optimizer, emb_optimizer) in zip([train_model.get("loss")], optimizer_list): + # do model optimization + grads = model_optimizer.compute_gradients(loss, var_list=train_variables) + avg_grads = [] + for grad, var in grads: + if rank_size > 1: + grad = hccl_ops.allreduce(grad, "sum") if grad is not None else None + if grad is not None: + avg_grads.append((grad / 8.0, var)) + # apply gradients: update variables + train_ops.append(model_optimizer.apply_gradients(avg_grads)) + + if use_dynamic_expansion: + train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET) + train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB) + # do embedding optimization by addr + sparse_grads = emb_optimizer.compute_gradients(loss, train_emb_list) # local_embedding + grads_and_vars = [(grad, address) for grad, address in zip(sparse_grads, train_address_list)] + train_ops.append(emb_optimizer.apply_gradients(grads_and_vars)) + else: + # do embedding optimization + sparse_grads = emb_optimizer.compute_gradients(loss, emb_variables) + print("sparse_grads_tensor:", sparse_grads) + grads_and_vars = [(grad, variable) for grad, variable in zip(sparse_grads, emb_variables)] + train_ops.append(emb_optimizer.apply_gradients(grads_and_vars)) + +# 动态学习率更新 +train_ops.extend([cfg.global_step.assign(cfg.global_step + 1), 
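+                  # cfg.learning_rate is the (dense_lr, sparse_lr) tensor pair
+                  # returned by LearningRateScheduler.calc in config.py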
cfg.learning_rate[0], cfg.learning_rate[1]]) +``` +410-442行代码是模型的反向过程操作。mxRec对推荐模型中sparse_feature的创表查表操作作了加速,使用`create_table`与`sparse_lookup`接口替换tensorflow中的`tf.nn.embedding_lookup`接口。因此模型反向更新分为两部分:417-425行代码是对`model.py`内的模型部分的反向;427-439行代码是对sparse_feature作embedding操作部分的反向过程,根据是否开启`动态扩容`选择不同的参数计算梯度并更新权重。 + +*** + +#### 4. optimizer.py +如上所述,模型反向过程分为`model.py`与`embedding`两部分;`model.py`可使用tf原生的优化器,`embedding`部分选择mxrec提供的`lazy_adam`或`lazy_adam_by_addr`优化器。`delay_loss_scale.py`包装`dense_optimizer`与`sparse_optimizer`并对其应用损失缩放技术,该技术主要作用于混合精度训练过程中。 + +```python +import tensorflow as tf +from delay_loss_scale import DenseLossScaleOptimizer, SparseLossScaleOptimizer +from mx_rec.util.initialize import ConfigInitializer +from mx_rec.optimizers.lazy_adam import create_hash_optimizer +from mx_rec.optimizers.lazy_adam_by_addr import create_hash_optimizer_by_address + + +def get_dense_and_sparse_optimizer(cfg): + dense_optimizer = tf.train.AdamOptimizer(learning_rate=cfg.learning_rate[0]) + use_dynamic_expansion = ConfigInitializer.get_instance().use_dynamic_expansion + if use_dynamic_expansion: + sparse_optimizer = create_hash_optimizer_by_address(learning_rate=cfg.learning_rate[1]) + else: + sparse_optimizer = create_hash_optimizer(learning_rate=cfg.learning_rate[1]) + sparse_optimizer = SparseLossScaleOptimizer(sparse_optimizer, 1) + dense_optimizer = DenseLossScaleOptimizer(dense_optimizer, 1) + + return dense_optimizer, sparse_optimizer +``` + + diff --git a/examples/WideDeep/criteo.py b/examples/WideDeep/criteo.py new file mode 100644 index 0000000000000000000000000000000000000000..3c8ea4307edde0aa7807edf4f9925b0e535e208e --- /dev/null +++ b/examples/WideDeep/criteo.py @@ -0,0 +1,272 @@ +# coding=utf-8 +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import os +import stat +import pickle +import argparse +import pandas as pd +import numpy as np +import tensorflow as tf +from tqdm import tqdm + +NAMES = ['label'] + [f'I{i}' for i in range(1, 14)] + [f'C{i}' for i in range(1, 27)] + + +def make_sub_file(lines, head, src_name, sub_dir_name, sub): + """Write sub-data. + Args: + :param lines: A list. Several pieces of data. + :param head: A string. ['label', 'I1', 'I2', ...]. + :param src_name: A string. The name of data. + :param sub_dir_name: A string. + :param sub: A scalar(Int). Record the current number of sub file. + :return: sub + 1. + """ + root_path, file_path = os.path.split(src_name) + file_name, suffix = file_path.split('.') + split_file_name = file_name + "_" + str(sub).zfill(2) + "." 
+ suffix + split_file = os.path.join(root_path, sub_dir_name, split_file_name) + if not os.path.exists(os.path.join(root_path, sub_dir_name)): + os.mkdir(os.path.join(root_path, sub_dir_name)) + + modes = stat.S_IWUSR | stat.S_IRUSR + flags = os.O_WRONLY | os.O_TRUNC | os.O_CREAT + f = os.fdopen(os.open(split_file, flags, modes), 'w') + try: + f.writelines([head]) + f.writelines(lines) + return sub + 1 + finally: + f.close() + + +def split_byline_count(filename, count, sub_dir_name): + """Split File. + Note: You can specify how many rows of data each sub file contains. + Args: + :param filename: A string. + :param count: A scalar(int). + :param sub_dir_name: A string. + :return: + """ + f = open(filename, 'r') + try: + head = f.readline() + buf = [] + sub = 1 + for line in f: + buf.append(line) + if len(buf) == count: + sub = make_sub_file(buf, head, filename, sub_dir_name, sub) + buf = [] + if len(buf) != 0: + try: + make_sub_file(buf, head, filename, sub_dir_name, sub) + except FileNotFoundError as err: + raise FileNotFoundError("please check the filename of data") from err + finally: + f.close() + + +def get_split_file_path(parent_path=None, dataset_path=None, sample_num=4600000): + """Get the list of split file path. + Note: Either parent_path or dataset_path must be valid. + If exists dataset_path + "/split", parent_path = dataset_path + "/split". + Args: + :param parent_path: A string. split file's parent path. + :param dataset_path: A string. + :param sample_num: A int. The sample number of every split file. + :return: A list. [file1_path, file2_path, ...] + """ + sub_dir_name = 'split' + if parent_path is None and dataset_path is None: + raise ValueError('Please give parent path or file path.') + if parent_path is None and os.path.exists(os.path.join(os.path.dirname(dataset_path), sub_dir_name)): + parent_path = os.path.join(os.path.dirname(dataset_path), sub_dir_name) + elif parent_path is None or not os.path.exists(parent_path): + split_byline_count(dataset_path, sample_num, sub_dir_name) + parent_path = os.path.join(os.path.dirname(dataset_path), sub_dir_name) + split_file_name = os.listdir(parent_path) + split_file_name.sort() + split_file_list = [parent_path + "/" + file_name for file_name in split_file_name if file_name[-3:] == 'txt'] + return split_file_list + + +def get_fea_map(fea_map_path=None, split_file_list=None): + """Get feature map. + Note: Either parent_path or dataset_path must be valid. + If exists dir(split_file_list[0]) + "/fea_map.pkl", fea_map_path is valid. + If fea_map_path is None and you want to build the feature map, + the default file path is the parent directory of split file + "fea_map.pkl". + Args: + :param fea_map_path: A string. + :param split_file_list: A list. [file1_path, file2_path, ...] + :return: A dict. 
{'C1':{}, 'C2':{}, ...} + """ + if fea_map_path is None and split_file_list is None: + raise ValueError('Please give feature map path or split file list.') + if fea_map_path is None and split_file_list is not None: + fea_map_path = os.path.join(os.path.dirname(split_file_list[0]), "fea_map.pkl") + if os.path.exists(fea_map_path) and fea_map_path[-3:] == 'pkl': + with open(fea_map_path, 'rb') as f: + fea_map = pickle.load(f) + return fea_map + fea_map = {} + for file_open in tqdm(split_file_list): + f = open(file_open) + for line in f: + row = line.strip('\n').split('\t') + for i in range(14, 40): + if row[i] == '': + continue + name = NAMES[i] + fea_map.setdefault(name, {}) + if fea_map[name].get(row[i]) is None: + fea_map[name][row[i]] = len(fea_map[name]) + for j in range(1, 14): + if row[j] == '': + continue + name = NAMES[j] + fea_map.setdefault(name, {}) + fea_map[name].setdefault('min', float(row[j])) + fea_map[name].setdefault('max', float(row[j])) + fea_map[name]['min'] = min(fea_map[name]['min'], float(row[j])) + fea_map[name]['max'] = max(fea_map[name]['max'], float(row[j])) + f.close() + for i in range(14, 40): + fea_map[NAMES[i]]['-1'] = len(fea_map[NAMES[i]]) + fea_map_path = os.path.join(os.path.dirname(split_file_list[0]), "fea_map.pkl") + + + modes = stat.S_IWUSR | stat.S_IRUSR + flags = os.O_WRONLY | os.O_TRUNC | os.O_CREAT + with os.fdopen(os.open(fea_map_path, flags, modes), 'wb') as fd: + pickle.dump(fea_map, fd, pickle.HIGHEST_PROTOCOL) + + return fea_map + + +def rec_kbins_discretizer(dat, n_bins, min_max_dict): + """Bin continuous data into intervals. + Note: The strategy is "uniform". + Args: + :param dat: A dataframe. + :param n_bins: A scalar(int). + :param min_max_dict: A dict such as {'min': , 'max': }. + :return: The new dataframe. 
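+    Example (assumed numbers): min=0, max=10, n_bins=5 gives edges [0, 2, 4, 6, 8, 10];
+    a value of 7.3 is digitized against edges[1:] and lands in bin 3.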
+ """ + features = dat.columns + n_features = len(features) + bin_edges = np.zeros(n_features, dtype=object) + for idx, feature in enumerate(features): + bin_edges[idx] = np.linspace(min_max_dict[feature]['min'], min_max_dict[feature]['max'], n_bins + 1) + rtol = 1.e-5 + atol = 1.e-8 + eps = atol + rtol * np.abs(dat[feature]) + dat[feature] = np.digitize(dat[feature] + eps, bin_edges[idx][1:]) + return dat + + +def convert_input2tfrd(in_file_path, out_file_path): + """ + txt to tfrecords + """ + def make_example(label_list, dense_feat_list, sparse_feat_list): + dense_feature = np.array(dense_feat_list, dtype=np.int64).reshape(-1) + sparse_feature = np.array(sparse_feat_list, dtype=np.int64).reshape(-1) + label = np.array(label_list, dtype=np.int64).reshape(-1) + feature_dict = { + "dense_feature": tf.train.Feature(int64_list=tf.train.Int64List(value=dense_feature)), + "sparse_feature": tf.train.Feature(int64_list=tf.train.Int64List(value=sparse_feature)), + "label": tf.train.Feature(int64_list=tf.train.Int64List(value=label)) + } + example = tf.train.Example(features=tf.train.Features(feature=feature_dict)) + + return example + + file_name = out_file_path + in_file_path[-12:-4] + '.tfrecord' + file_writer = tf.io.TFRecordWriter(file_name) + + with open(in_file_path, encoding='utf-8') as file_in: + + for _, line in tqdm(enumerate(file_in)): + + line = line.strip('\n') + items = line.split('\t') + if len(items) != 40: + continue + label = int(items[0]) + dense = items[1:14] + sparse = items[14:] + + ex = make_example(label, dense, sparse) + serialized = ex.SerializeToString() + file_writer.write(serialized) + + file_writer.close() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Get datasets') + parser.add_argument('--data_path') + parser.add_argument('--output_path') + + args, _ = parser.parse_known_args() + data_path = args.data_path + output_path = args.output_path + + # get txt_list + file_split_list = get_split_file_path(dataset_path=data_path) + # get feature_map + feature_map = get_fea_map(split_file_list=file_split_list) + + for file in tqdm(file_split_list): + + # read data + data_df = pd.read_csv(file, sep='\t', header=None, names=NAMES) + # name feature + sparse_features = ['C' + str(i) for i in range(1, 27)] + dense_features = ['I' + str(i) for i in range(1, 14)] + # data processing + data_df[sparse_features] = data_df[sparse_features].fillna('-1') + data_df[dense_features] = data_df[dense_features].fillna(0) + # sparse feature: mapping + for col in sparse_features: + try: + data_df[col] = data_df[col].map(lambda x: feature_map[col][x]) + except KeyError as e: + raise KeyError("Feature {} not found in dataset".format(col)) from e + # dense feature: Bin continuous data into intervals. 
+ data_df[dense_features] = rec_kbins_discretizer(data_df[dense_features], 1000, feature_map) + # add offsets + slot_size_array = [ + 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, + 1462, 585, 10131228, 2202609, 307, 25, 12519, 635, 5, 93147, 5685, 8351594, 3196, + 29, 14994, 5461307, 12, 5654, 2174, 5, 7046548, 19, 17, 286182, 106, 142573 + ] + offset_size_list = np.cumsum([0] + slot_size_array[:-1]) + for col_index in range(1, len(offset_size_list) + 1): + data_df.iloc[:, col_index] += offset_size_list[col_index - 1] + # save to txt + data_df.to_csv(file, sep='\t', index=False, header=False) + # txt to tfrecords + convert_input2tfrd(in_file_path=file, out_file_path=output_path) + + + + + diff --git a/examples/WideDeep/model/config.py b/examples/WideDeep/model/config.py new file mode 100644 index 0000000000000000000000000000000000000000..0072dc595d83bf8f5cdae1823445afe50920bc00 --- /dev/null +++ b/examples/WideDeep/model/config.py @@ -0,0 +1,235 @@ +# coding=utf-8 +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import enum +import os + +import tensorflow as tf +from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig +from npu_bridge.estimator.npu.npu_config import NPURunConfig + +from mx_rec.constants.constants import CacheModeEnum + +SSD_DATA_PATH = ["ssd_data"] + + +class LearningRateScheduler: + """ + LR Scheduler combining Polynomial Decay with Warmup at the beginning. + TF-based cond operations necessary for performance in graph mode. 
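+    Note: as wired in this example, calc() returns base_lr * lr_factor_constant,
+    so both learning rates stay fixed (the warmup/decay factors are computed but unused).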
+ """ + + def __init__(self, base_lr_dense, base_lr_sparse, warmup_steps, decay_start_step, decay_steps): + self.warmup_steps = tf.constant(warmup_steps, dtype=tf.int32) + self.decay_start_step = tf.constant(decay_start_step, dtype=tf.int32) + self.decay_steps = tf.constant(decay_steps) + self.decay_end_step = decay_start_step + decay_steps # 65041 + self.poly_power = 2.0 + self.base_lr_dense = base_lr_dense + self.base_lr_sparse = base_lr_sparse + + def calc(self, global_step): + # used for the warmup stage + warmup_step = tf.cast(1 / self.warmup_steps, tf.float32) + lr_factor_warmup = 1 - tf.cast(self.warmup_steps - global_step, tf.float32) * warmup_step + lr_factor_warmup = tf.cast(lr_factor_warmup, tf.float32) + # used for the constant stage + lr_factor_constant = tf.cast(1.0, tf.float32) + + # used for the decay stage + lr_factor_decay = (self.decay_end_step - global_step) / self.decay_steps + lr_factor_decay = tf.math.pow(lr_factor_decay, self.poly_power) + lr_factor_decay = tf.cast(lr_factor_decay, tf.float32) + sparse_after_decay = tf.cast(1 / self.decay_steps, tf.float32) + + lr_factor_decay_sparse = tf.cond( + global_step < self.decay_end_step, + lambda: lr_factor_decay, + lambda: sparse_after_decay, + ) + + lr_factor_decay_dense = tf.cond( + global_step < self.decay_end_step, + lambda: lr_factor_decay, + lambda: sparse_after_decay, + ) + + poly_schedule_sparse = tf.cond( + global_step < self.decay_start_step, + lambda: lr_factor_constant, + lambda: lr_factor_decay_sparse, + ) + + poly_schedule_dense = tf.cond( + global_step < self.decay_start_step, + lambda: lr_factor_constant, + lambda: lr_factor_decay_dense, + ) + + lr_factor_sparse = tf.cond( + global_step < self.warmup_steps, lambda: lr_factor_warmup, lambda: poly_schedule_sparse + ) + + lr_factor_dense = tf.cond( + global_step < self.warmup_steps, lambda: lr_factor_warmup, lambda: poly_schedule_dense + ) + + lr_sparse = self.base_lr_sparse * lr_factor_constant + lr_dense = self.base_lr_dense * lr_factor_constant + return lr_dense, lr_sparse + + +class Config: + def __init__(self, ): + self.rank_id = int(os.getenv("OMPI_COMM_WORLD_RANK")) if os.getenv("OMPI_COMM_WORLD_RANK") else None + tmp = os.getenv("TRAIN_RANK_SIZE") + if tmp is None: + raise ValueError("please export TRAIN_RANK_SIZE") + self.rank_size = int(tmp) + + self.data_path = os.getenv("DLRM_CRITEO_DATA_PATH") + self.train_file_pattern = "train" + self.test_file_pattern = "test" + + self.batch_size = 4096 + self.line_per_sample = 1 + self.train_epoch = 1 + self.test_epoch = 9 + self.perform_shuffle = False + + self.key_type = tf.int64 + self.label_type = tf.float32 + self.value_type = tf.int64 + + self.feat_cnt = 26 + self.__set_emb_table_size() + + self.field_num = 26 + self.send_count = 46000 // self.rank_size + + self.emb_dim = 8 + self.hashtable_threshold = 1 + + self.USE_PIPELINE_TEST = False + + # 动态学习率 + GLOBAL_BATCH_SIZE = 8192 * 8 + LR_SCHEDULE_STEPS = [ + int(2750 * 55296 / GLOBAL_BATCH_SIZE), + int(49315 * 55296 / GLOBAL_BATCH_SIZE), + int(27772 * 55296 / GLOBAL_BATCH_SIZE), + ] + self.global_step = tf.Variable(0, trainable=False) + _lr_scheduler = LearningRateScheduler( + 0.001, + 0.001, + LR_SCHEDULE_STEPS[0], + LR_SCHEDULE_STEPS[1], + LR_SCHEDULE_STEPS[2], + ) + self.learning_rate = _lr_scheduler.calc(self.global_step) + + def __set_emb_table_size(self): + self.cache_mode = os.getenv("CACHE_MODE") + if self.cache_mode is None: + raise ValueError("please export CACHE_MODE environment variable, support:[HBM, DDR, SSD]") + + if self.cache_mode == 
CacheModeEnum.HBM.value: + self.dev_vocab_size = 14_000_000 * self.rank_size + self.host_vocab_size = 0 + elif self.cache_mode == CacheModeEnum.DDR.value: + self.dev_vocab_size = 500_000 * self.rank_size + self.host_vocab_size = 24_000_000 * self.rank_size + elif self.cache_mode == CacheModeEnum.SSD.value: + self.dev_vocab_size = 100_000 * self.rank_size + self.host_vocab_size = 2_000_000 * self.rank_size + self.ssd_vocab_size = 24_000_000 * self.rank_size + else: + raise ValueError(f"get CACHE_MODE:{self.cache_mode}, expect in [HBM, DDR, SSD]") + + def get_emb_table_cfg(self): + if self.cache_mode == CacheModeEnum.HBM.value: + return {"device_vocabulary_size": self.dev_vocab_size} + elif self.cache_mode == CacheModeEnum.DDR.value: + return {"device_vocabulary_size": self.dev_vocab_size, + "host_vocabulary_size": self.host_vocab_size} + elif self.cache_mode == CacheModeEnum.SSD.value: + return {"device_vocabulary_size": self.dev_vocab_size, + "host_vocabulary_size": self.host_vocab_size, + "ssd_vocabulary_size": self.ssd_vocab_size, + "ssd_data_path": SSD_DATA_PATH} + else: + raise RuntimeError(f"get CACHE_MODE:{self.cache_mode}, check Config.__set_emb_table_size implementation") + + +def sess_config(dump_data=False, dump_path="./dump_output", dump_steps="0|1|2"): + session_config = tf.ConfigProto(allow_soft_placement=False, + log_device_placement=False) + session_config.gpu_options.allow_growth = True + custom_op = session_config.graph_options.rewrite_options.custom_optimizers.add() + custom_op.name = "NpuOptimizer" + custom_op.parameter_map["mix_compile_mode"].b = False + custom_op.parameter_map["use_off_line"].b = True + custom_op.parameter_map["min_group_size"].b = 1 + # 可选配置level0:pairwise;level1:pairwise + custom_op.parameter_map["HCCL_algorithm"].s = tf.compat.as_bytes("level0:fullmesh;level1:fullmesh") + custom_op.parameter_map["enable_data_pre_proc"].b = True + custom_op.parameter_map["iterations_per_loop"].i = 10 + custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision") + custom_op.parameter_map["hcom_parallel"].b = False + custom_op.parameter_map["op_precision_mode"].s = tf.compat.as_bytes("op_impl_mode.ini") + custom_op.parameter_map["op_execute_timeout"].i = 2000 + custom_op.parameter_map["variable_memory_max_size"].s = tf.compat.as_bytes( + str(13 * 1024 * 1024 * 1024)) # total 31 need 13; + custom_op.parameter_map["graph_memory_max_size"].s = tf.compat.as_bytes(str(18 * 1024 * 1024 * 1024)) # need 25 + custom_op.parameter_map["stream_max_parallel_num"].s = tf.compat.as_bytes("DNN_VM_AICPU:3,AIcoreEngine:3") + + if dump_data: + custom_op.parameter_map["enable_dump"].b = True + custom_op.parameter_map["dump_path"].s = tf.compat.as_bytes(dump_path) + custom_op.parameter_map["dump_step"].s = tf.compat.as_bytes(dump_steps) + custom_op.parameter_map["dump_mode"].s = tf.compat.as_bytes("all") + + session_config.graph_options.rewrite_options.remapping = RewriterConfig.OFF + session_config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF + + return session_config + + +def get_npu_run_config(): + session_config = tf.ConfigProto(allow_soft_placement=False, + log_device_placement=False) + + session_config.gpu_options.allow_growth = True + custom_op = session_config.graph_options.rewrite_options.custom_optimizers.add() + custom_op.name = "NpuOptimizer" + session_config.graph_options.rewrite_options.remapping = RewriterConfig.OFF + session_config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF + + run_config = 
NPURunConfig( + save_summary_steps=1000, + save_checkpoints_steps=100, + keep_checkpoint_max=5, + session_config=session_config, + log_step_count_steps=20, + precision_mode='allow_mix_precision', + enable_data_pre_proc=True, + iterations_per_loop=1, + jit_compile=False, + op_compiler_cache_mode="enable", + HCCL_algorithm="level0:fullmesh;level1:fullmesh" # 可选配置:level0:pairwise;level1:pairwise + ) + return run_config diff --git a/examples/WideDeep/model/delay_loss_scale.py b/examples/WideDeep/model/delay_loss_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..f73baf68f14e0c9139391cf52995372bd156ff15 --- /dev/null +++ b/examples/WideDeep/model/delay_loss_scale.py @@ -0,0 +1,64 @@ +# coding=utf-8 +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import tensorflow as tf +from tensorflow.python.training import optimizer + +from config import Config + + +class DenseLossScaleOptimizer: + def __init__(self, opt: optimizer.Optimizer, loss_scale: int) -> None: + if not isinstance(opt, optimizer.Optimizer): + raise ValueError('"opt" must be an instance of Optimizer, but got: %s' % type(opt)) + self._optimizer = opt + self._loss_scale = tf.convert_to_tensor(loss_scale, tf.float32) + _update_lr_loss_scale(self._optimizer, loss_scale) + + def compute_gradients(self, loss, var_list=None): + return self._optimizer.compute_gradients(loss * self._loss_scale, var_list=var_list) + + def apply_gradients(self, avg_grads): + return self._optimizer.apply_gradients(avg_grads) + + +class SparseLossScaleOptimizer: + def __init__(self, opt: optimizer.Optimizer, loss_scale: int) -> None: + if not isinstance(opt, optimizer.Optimizer): + raise ValueError('"opt" must be an instance of Optimizer, but got: %s' % type(opt)) + self._optimizer = opt + self._loss_scale = tf.convert_to_tensor(loss_scale, tf.float32) + _update_lr_loss_scale(self._optimizer, loss_scale) + + def compute_gradients(self, loss, var_list=None): + return tf.gradients(loss * self._loss_scale, var_list) + + def apply_gradients(self, grads_and_vars): + return self._optimizer.apply_gradients(grads_and_vars) + + +def _update_lr_loss_scale(opt, loss_scale): + if loss_scale <= 0: + raise RuntimeError("the loss_scale must be greater than zero.") + loss_scale = tf.convert_to_tensor(loss_scale, tf.float32) + if hasattr(opt, "_lr"): + # LazyAdam or Adam optimizer + opt._lr = opt._lr / loss_scale + elif hasattr(opt, "_learning_rate"): + # SGD optimizer + opt._learning_rate = opt._learning_rate / loss_scale + else: + raise RuntimeError("`opt` should have a `_learning_rate` or `_lr` named field.") \ No newline at end of file diff --git a/examples/WideDeep/model/gradient_descent_w.py b/examples/WideDeep/model/gradient_descent_w.py new file mode 100644 index 0000000000000000000000000000000000000000..53adb996bb20424fc91a3722dee7270a45465ddc --- /dev/null +++ 
b/examples/WideDeep/model/gradient_descent_w.py @@ -0,0 +1,71 @@ +# coding=utf-8 +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import defaultdict + +import tensorflow as tf +from tensorflow.python.ops import math_ops +from tensorflow.python.training import gradient_descent +from mx_rec.optimizers.base import CustomizedOptimizer +from mx_rec.util.log import logger +from mx_rec.util.initialize import ConfigInitializer + + +def create_hash_optimizer(learning_rate, weight_decay=0.0001, use_locking=False, name="GradientDescent"): + optimizer = CustomizedGradientDescentWithWeighDecay(learning_rate=learning_rate, + weight_decay=weight_decay, + use_locking=use_locking, + name=name) + ConfigInitializer.get_instance().optimizer_config.optimizer_instance = optimizer + return optimizer + + +class CustomizedGradientDescentWithWeighDecay(gradient_descent.GradientDescentOptimizer, CustomizedOptimizer): + name_counter = defaultdict(int) + + def __init__(self, learning_rate, weight_decay, use_locking=False, name="GradientDescent"): + self.optimizer_type = "gradient_descent_with_weight_decay" + self.weight_decay = weight_decay + super(CustomizedGradientDescentWithWeighDecay, self)._get_name(name=name) + super(CustomizedGradientDescentWithWeighDecay, self).__init__( + learning_rate=learning_rate, use_locking=use_locking, name=self.unique_name + ) + self._slot_num = 0 + self._derivative = 1 + + def get_slot_init_values(self): + logger.info("no slot for gradient descent") + return [] + + def _apply_sparse_duplicate_indices(self, grad, var): + logger.debug(">>>> Enter _apply_sparse_duplicate_indices") + nd_indices = tf.expand_dims(grad.indices, 1) + logger.info(f"weigh_decay={self.weight_decay}") + if self.weight_decay is None: + nd_value = grad.values * math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype) + else: + nd_value = (grad.values + math_ops.cast(self.weight_decay, var.dtype.base_dtype) * + tf.gather(var, grad.indices)) * math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype) + var_update_op = tf.scatter_nd_add(var, nd_indices, -nd_value, use_locking=self._use_locking) + return var_update_op + + def _apply_dense(self, grad, var): + logger.debug(">>>> Enter _apply_dense") + raise NotImplementedError("You are using a wrong type of variable.") diff --git a/examples/WideDeep/model/main_mxrec.py b/examples/WideDeep/model/main_mxrec.py new file mode 100644 index 0000000000000000000000000000000000000000..0a7c2f8787871a959b5f59fa372163a95d4f3fdc --- /dev/null +++ b/examples/WideDeep/model/main_mxrec.py @@ -0,0 +1,546 @@ +# coding=utf-8 +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import os +import shutil +import collections +import time +import warnings +import random +from glob import glob + +import tensorflow as tf +from sklearn.metrics import roc_auc_score +import numpy as np + +from optimizer import get_dense_and_sparse_optimizer +from config import sess_config, Config, SSD_DATA_PATH, CacheModeEnum +from model import MyModel +from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET +from mx_rec.core.asc.helper import FeatureSpec, get_asc_insert_func +from mx_rec.core.asc.manager import start_asc_pipeline +from mx_rec.core.embedding import create_table, sparse_lookup +from mx_rec.core.feature_process import EvictHook +from mx_rec.graph.modifier import modify_graph_and_start_emb_cache, GraphModifierHook +from mx_rec.constants.constants import ASCEND_TIMESTAMP, LIBREC_EOS_OPS_SO +from mx_rec.util.initialize import ConfigInitializer, init, terminate_config_initializer +from mx_rec.util.ops import import_host_pipeline_ops +import mx_rec.util as mxrec_util +from mx_rec.util.variable import get_dense_and_sparse_variable +from mx_rec.util.log import logger +from npu_bridge.npu_init import * + +npu_plugin.set_device_sat_mode(0) + +dense_hashtable_seed = 128 +sparse_hashtable_seed = 128 +shuffle_seed = 128 +random.seed(shuffle_seed) + + +def add_timestamp_func(batch): + timestamp = import_host_pipeline_ops().return_timestamp(tf.cast(batch['label'], dtype=tf.int64)) + batch["timestamp"] = timestamp + return batch + + +def make_batch_and_iterator(config, feature_spec_list, is_training, dump_graph, is_use_faae=False, **kwargs): + if config.USE_PIPELINE_TEST: + num_parallel = 1 + else: + num_parallel = 8 + + def extract_fn(data_record): + features = { + # Extract features using the keys set during creation + 'label': tf.compat.v1.FixedLenFeature(shape=(config.line_per_sample,), dtype=tf.int64), + 'sparse_feature': tf.compat.v1.FixedLenFeature(shape=(26 * config.line_per_sample,), dtype=tf.int64), + 'dense_feature': tf.compat.v1.FixedLenFeature(shape=(13 * config.line_per_sample,), dtype=tf.int64), + } + sample = tf.compat.v1.parse_single_example(data_record, features) + return sample + + def reshape_fn(batch): + batch['label'] = tf.reshape(batch['label'], [-1, 1]) + batch['dense_feature'] = tf.reshape(batch['dense_feature'], [-1, 13]) + batch['sparse_feature'] = tf.reshape(batch['sparse_feature'], [-1, 26]) + return batch + + if is_training: + files_list = glob(os.path.join(config.data_path, config.train_file_pattern) + '/*.tfrecord') + else: + files_list = glob(os.path.join(config.data_path, config.test_file_pattern) + '/*.tfrecord') + dataset = tf.data.TFRecordDataset(files_list, num_parallel_reads=num_parallel) + batch_size = config.batch_size // config.line_per_sample + + dataset = dataset.shard(config.rank_size, config.rank_id) + if is_training: + dataset = dataset.shuffle(batch_size * 1000, 
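+                                  # shuffle buffer = batch_size * 1000 records (applied
+                                  # before batching); fixed shuffle_seed keeps runs reproducible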
seed=shuffle_seed) + dataset = dataset.repeat(config.train_epoch) + else: + dataset = dataset.repeat(config.test_epoch) + dataset = dataset.map(extract_fn, num_parallel_calls=num_parallel).batch(batch_size, + drop_remainder=True) + dataset = dataset.map(reshape_fn, num_parallel_calls=num_parallel) + + def map_fn(batch): + new_batch = batch + new_batch['sparse_feature'] = tf.concat([batch['dense_feature'], batch['sparse_feature']], axis=1) + return new_batch + dataset = dataset.map(map_fn, num_parallel_calls=num_parallel) + + if is_use_faae: + dataset = dataset.map(add_timestamp_func) + + if not MODIFY_GRAPH_FLAG: + + # Enable EOSDataset manually. + librec = import_host_pipeline_ops(LIBREC_EOS_OPS_SO) + channel_id = 0 if is_training else 1 + # 此处eos_map的调用必须先于insert_func,避免多卡数据不均匀的情况 + dataset = dataset.eos_map(librec, channel_id, kwargs.get("max_train_steps", max_train_steps), + kwargs.get("max_eval_steps", eval_steps)) + insert_fn = get_asc_insert_func(tgt_key_specs=feature_spec_list, is_training=is_training, dump_graph=dump_graph) + dataset = dataset.map(insert_fn) + + dataset = dataset.prefetch(100) + + iterator = dataset.make_initializable_iterator() + batch = iterator.get_next() + return batch, iterator + + +def model_forward(model_args): + feature_list = model_args.feature_list + wide_hash_table_list = model_args.wide_hash_table_list + deep_hash_table_list = model_args.deep_hash_table_list + batch = model_args.batch + is_train = model_args.is_train + modify_graph = model_args.modify_graph + is_use_faae = model_args.is_use_faae + + wide_embedding_list = [] + deep_embedding_list = [] + wide_feature_list = [] + deep_feature_list = [] + if is_use_faae: + feature_list_copy = feature_list[:-1] + else: + feature_list_copy = feature_list + + for index, item in enumerate(feature_list_copy): + if index % 2 == 0: + wide_feature_list.append(item) + else: + deep_feature_list.append(item) + + logger.debug(f"In model_forward function, is_train: {is_train}, feature_list: {len(feature_list)}, " + f"wide_hash_table_list: {len(wide_hash_table_list)}, " + f"deep_hash_table_list: {len(deep_hash_table_list)}") + + # wide + for wide_feature, wide_hash_table in zip(wide_feature_list, wide_hash_table_list): + if MODIFY_GRAPH_FLAG: + wide_feature = batch["sparse_feature"] + wide_embedding = sparse_lookup(wide_hash_table, wide_feature, cfg.send_count, dim=None, is_train=is_train, + name="wide_embedding_lookup", modify_graph=modify_graph, batch=batch, + access_and_evict_config=None) + wide_embedding_list.append(wide_embedding) + + # deep + for deep_feature, deep_hash_table in zip(deep_feature_list, deep_hash_table_list): + if MODIFY_GRAPH_FLAG: + deep_feature = batch["sparse_feature"] + deep_embedding = sparse_lookup(deep_hash_table, deep_feature, cfg.send_count, dim=None, is_train=is_train, + name="deep_embedding_lookup", modify_graph=modify_graph, batch=batch, + access_and_evict_config=None) + deep_embedding_list.append(deep_embedding) + + if len(wide_embedding_list) == 1: + wide_emb = wide_embedding_list[0] + deep_emb = deep_embedding_list[0] + elif len(wide_embedding_list) > 1: + wide_emb = tf.reduce_sum(wide_embedding_list, axis=0, keepdims=False) + deep_emb = tf.reduce_sum(deep_embedding_list, axis=0, keepdims=False) + else: + raise ValueError("the length of embedding_list must be greater than or equal to 1.") + my_model = MyModel() + + BuildModel = collections.namedtuple("BuildModel", ["wide_embedding", "deep_embedding", "label", "is_training", + "seed", "dropout_rate", "batch_norm"]) + build_model_args 
= BuildModel(wide_emb, deep_emb, batch["label"], is_train, dense_hashtable_seed, 0.5, False) + model_output = my_model.build_model(build_model_args) + return model_output + + +def evaluate(): + print("read_test dataset") + if not MODIFY_GRAPH_FLAG: + eval_label = eval_model.get("label") + sess.run([eval_iterator.initializer]) + else: + # 在sess run模式下,若还是使用原来batch中的label去sess run,则会出现getnext超时报错,需要使用新数据集中的batch + eval_label = ConfigInitializer.get_instance().train_params_config.get_target_batch(False).get("label") + sess.run([ConfigInitializer.get_instance().train_params_config.get_initializer(False)]) + log_loss_list = [] + pred_list = [] + label_list = [] + eval_current_steps = 0 + finished = False + print("eval begin") + + while not finished: + try: + eval_current_steps += 1 + eval_start = time.time() + eval_loss, pred, label = sess.run([eval_model.get("loss"), eval_model.get("pred"), eval_label]) + eval_cost = time.time() - eval_start + qps_eval = (1 / eval_cost) * rank_size * cfg.batch_size + log_loss_list += list(eval_loss.reshape(-1)) + pred_list += list(pred.reshape(-1)) + label_list += list(label.reshape(-1)) + print(f"eval current_steps: {eval_current_steps}, qps: {qps_eval}") + if eval_current_steps == eval_steps: + finished = True + except tf.errors.OutOfRangeError: + finished = True + auc = roc_auc_score(label_list, pred_list) + mean_log_loss = np.mean(log_loss_list) + return auc, mean_log_loss + + +def evaluate_fix(step): + print("read_test dataset evaluate_fix") + if not MODIFY_GRAPH_FLAG: + sess.run([eval_iterator.initializer]) + else: + sess.run([ConfigInitializer.get_instance().train_params_config.get_initializer(False)]) + log_loss_list = [] + pred_list = [] + label_list = [] + eval_current_steps = 0 + finished = False + print("eval begin") + while not finished: + try: + eval_current_steps += 1 + eval_loss, pred, label = sess.run([eval_model.get("loss"), eval_model.get("pred"), eval_model.get("label")]) + log_loss_list += list(eval_loss.reshape(-1)) + pred_list += list(pred.reshape(-1)) + label_list += list(label.reshape(-1)) + print(f"eval current_steps: {eval_current_steps}") + + if eval_current_steps == eval_steps: + finished = True + except tf.errors.OutOfRangeError: + finished = True + + label_numpy = np.array(label_list) + pred_numpy = np.array(pred_list) + if not os.path.exists(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}"): + os.makedirs(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}") + + if os.path.exists(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/label_{rank_id}.npy"): + os.remove(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/label_{rank_id}.npy") + if os.path.exists(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/pred_{rank_id}.npy"): + os.remove(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/pred_{rank_id}.npy") + if os.path.exists(f"flag_{rank_id}.txt"): + os.remove(f"flag_{rank_id}.txt") + np.save(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/label_{rank_id}.npy", label_numpy) + np.save(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/pred_{rank_id}.npy", pred_numpy) + os.mknod(f"flag_{rank_id}.txt") + while True: + file_exists_list = [os.path.exists(f"flag_{i}.txt") for i in range(rank_size)] + if sum(file_exists_list) == rank_size: + print("All saved!!!!!!!!!!") + break + else: + print("Waitting for saving numpy!!!!!!!!") + time.sleep(1) + continue + + auc = roc_auc_score(label_list, pred_list) + mean_log_loss = np.mean(log_loss_list) + return auc, 
mean_log_loss + + +def create_feature_spec_list(use_timestamp=False): + access_threshold = None + eviction_threshold = None + if use_timestamp: + access_threshold = 1000 + eviction_threshold = 180 + + feature_spec_list = [ + FeatureSpec("sparse_feature", table_name="wide_embeddings", batch_size=cfg.batch_size, + access_threshold=access_threshold, eviction_threshold=eviction_threshold), + FeatureSpec("sparse_feature", table_name="deep_embeddings", batch_size=cfg.batch_size, + access_threshold=access_threshold, eviction_threshold=eviction_threshold) + ] + + if use_multi_lookup: + feature_spec_list.extend([FeatureSpec("sparse_feature", table_name="wide_embeddings", + batch_size=cfg.batch_size, + access_threshold=access_threshold, + eviction_threshold=eviction_threshold), + FeatureSpec("sparse_feature", table_name="deep_embeddings", + batch_size=cfg.batch_size, + access_threshold=access_threshold, + eviction_threshold=eviction_threshold)]) + if use_timestamp: + feature_spec_list.append(FeatureSpec("timestamp", is_timestamp=True)) + return feature_spec_list + + +def _del_related_dir(del_path: str) -> None: + if not os.path.isabs(del_path): + del_path = os.path.join(os.getcwd(), del_path) + dirs = glob(del_path) + for sub_dir in dirs: + shutil.rmtree(sub_dir, ignore_errors=True) + logger.info(f"Delete dir:{sub_dir}") + + +def _clear_saved_model() -> None: + _del_related_dir("/root/ascend/log/*") + _del_related_dir("kernel*") + _del_related_dir("model_dir_rank*") + _del_related_dir("op_cache") + + if os.getenv("CACHE_MODE", "") != CacheModeEnum.SSD.value: + return + logger.info("Current cache mode is SSD, and file overwrite is not allowed in SSD mode, deleting exist directory" + " then create empty directory for this use case.") + for sub_path in SSD_DATA_PATH: + _del_related_dir(sub_path) + os.makedirs(sub_path, mode=0o550, exist_ok=True) + logger.info(f"Create dir:{sub_path}") + + +if __name__ == "__main__": + tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) + warnings.filterwarnings("ignore") + _clear_saved_model() + + rank_id = int(os.getenv("RANK_ID")) if os.getenv("RANK_ID") else None + rank_size = int(os.getenv("TRAIN_RANK_SIZE")) if os.getenv("TRAIN_RANK_SIZE") else None + interval = int(os.getenv("INTERVAL")) if os.getenv("INTERVAL") else None + max_train_steps = 1270 + train_steps = 1120 + eval_steps = 1080 + + try: + use_dynamic_expansion = bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0))) + use_multi_lookup = bool(int(os.getenv("USE_MULTI_LOOKUP", 0))) + MODIFY_GRAPH_FLAG = bool(int(os.getenv("USE_MODIFY_GRAPH", 0))) + use_faae = bool(int(os.getenv("USE_FAAE", 0))) + except ValueError as err: + raise ValueError("please correctly config USE_DYNAMIC_EXPANSION or USE_MULTI_LOOKUP or USE_FAAE " + "or USE_MODIFY_GRAPH only 0 or 1 is supported.") from err + + use_dynamic = bool(int(os.getenv("USE_DYNAMIC", 0))) + logger.info(f"USE_DYNAMIC:{use_dynamic}") + init(train_steps=train_steps, eval_steps=eval_steps, + use_dynamic=use_dynamic, use_dynamic_expansion=use_dynamic_expansion) + IF_LOAD = False + rank_id = mxrec_util.communication.hccl_ops.get_rank_id() + filelist = glob(f"./saved-model/sparse-model-0") + if filelist: + IF_LOAD = True + ConfigInitializer.get_instance().if_load = IF_LOAD + + cfg = Config() + feature_spec_list_train = None + feature_spec_list_eval = None + if use_faae: + feature_spec_list_train = create_feature_spec_list(use_timestamp=True) + feature_spec_list_eval = create_feature_spec_list(use_timestamp=True) + else: + feature_spec_list_train = 
create_feature_spec_list(use_timestamp=False) + feature_spec_list_eval = create_feature_spec_list(use_timestamp=False) + + train_batch, train_iterator = make_batch_and_iterator(cfg, feature_spec_list_train, is_training=True, + dump_graph=True, is_use_faae=use_faae, + max_train_steps=max_train_steps, max_eval_steps=eval_steps) + eval_batch, eval_iterator = make_batch_and_iterator(cfg, feature_spec_list_eval, is_training=False, + dump_graph=False, is_use_faae=use_faae, + max_train_steps=max_train_steps, max_eval_steps=eval_steps) + logger.info(f"train_batch: {train_batch}") + + if use_faae: + cfg.dev_vocab_size = cfg.dev_vocab_size // 2 + + # 创表操作 + wide_emb_initializer = tf.compat.v1.truncated_normal_initializer(stddev=0.05, seed=sparse_hashtable_seed) + deep_emb_initializer = tf.compat.v1.truncated_normal_initializer(stddev=0.05, seed=sparse_hashtable_seed) + + sparse_hashtable_wide = create_table( + key_dtype=cfg.key_type, + dim=tf.TensorShape([cfg.emb_dim]), + name="wide_embeddings", + emb_initializer=wide_emb_initializer, + **cfg.get_emb_table_cfg() + ) + + sparse_hashtable_deep = create_table( + key_dtype=cfg.key_type, + dim=tf.TensorShape([cfg.emb_dim]), + name="deep_embeddings", + emb_initializer=deep_emb_initializer, + **cfg.get_emb_table_cfg() + ) + + if use_faae: + tf.compat.v1.add_to_collection(ASCEND_TIMESTAMP, train_batch["timestamp"]) + + # 一表多查 + wide_hashtable_list = [sparse_hashtable_wide, sparse_hashtable_wide] if use_multi_lookup else \ + [sparse_hashtable_wide] + deep_hashtable_list = [sparse_hashtable_deep, sparse_hashtable_deep] if use_multi_lookup else \ + [sparse_hashtable_deep] + + + Forward = collections.namedtuple("Forward", ["feature_list", "wide_hash_table_list", "deep_hash_table_list", + "batch", "is_train", "modify_graph", "is_use_faae"]) + train_forward_args = Forward(feature_spec_list_train, wide_hashtable_list, deep_hashtable_list, train_batch, + True, MODIFY_GRAPH_FLAG, use_faae) + eval_forward_args = Forward(feature_spec_list_eval, wide_hashtable_list, deep_hashtable_list, eval_batch, + False, MODIFY_GRAPH_FLAG, use_faae) + train_model = model_forward(train_forward_args) + eval_model = model_forward(eval_forward_args) + + train_variables, emb_variables = get_dense_and_sparse_variable() + optimizer_list = [get_dense_and_sparse_optimizer(cfg)] + + rank_size = mxrec_util.communication.hccl_ops.get_rank_size() + train_ops = [] + # multi task training + for loss, (model_optimizer, emb_optimizer) in zip([train_model.get("loss")], optimizer_list): + # do model optimization + grads = model_optimizer.compute_gradients(loss, var_list=train_variables) + avg_grads = [] + for grad, var in grads: + if rank_size > 1: + grad = hccl_ops.allreduce(grad, "sum") if grad is not None else None + if grad is not None: + avg_grads.append((grad / 8.0, var)) + # apply gradients: update variables + train_ops.append(model_optimizer.apply_gradients(avg_grads)) + + if use_dynamic_expansion: + train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET) + train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB) + # do embedding optimization by addr + sparse_grads = emb_optimizer.compute_gradients(loss, train_emb_list) # local_embedding + grads_and_vars = [(grad, address) for grad, address in zip(sparse_grads, train_address_list)] + train_ops.append(emb_optimizer.apply_gradients(grads_and_vars)) + else: + # do embedding optimization + sparse_grads = emb_optimizer.compute_gradients(loss, emb_variables) + print("sparse_grads_tensor:", sparse_grads) + 
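# Illustrative note, not part of the original diff: the zip below assumes the
# sparse optimizer's compute_gradients returns one gradient per entry of
# emb_variables, in the same order, so gradients and embedding tables are
# paired positionally before apply_gradients is called.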
grads_and_vars = [(grad, variable) for grad, variable in zip(sparse_grads, emb_variables)] + train_ops.append(emb_optimizer.apply_gradients(grads_and_vars)) + + # 动态学习率更新 + train_ops.extend([cfg.global_step.assign(cfg.global_step + 1), cfg.learning_rate[0], cfg.learning_rate[1]]) + + with tf.control_dependencies(train_ops): + train_ops = tf.no_op() + cfg.learning_rate = [cfg.learning_rate[0], cfg.learning_rate[1]] + + saver = tf.train.Saver() + if MODIFY_GRAPH_FLAG: + modify_graph_and_start_emb_cache(dump_graph=True) + else: + start_asc_pipeline() + + hook_list = [] + if use_faae: + hook_evict = EvictHook(evict_enable=True, evict_time_interval=120) + hook_list.append(hook_evict) + if MODIFY_GRAPH_FLAG: # 该场景添加hook处理校验问题 + hook_list.append(GraphModifierHook(modify_graph=False)) + + # with tf.compat.v1.Session(config=sess_config(dump_data=False)) as sess: + if use_faae: + sess = tf.compat.v1.train.MonitoredTrainingSession( + hooks=hook_list, + config=sess_config(dump_data=False) + ) + sess.graph._unsafe_unfinalize() + if not MODIFY_GRAPH_FLAG: + sess.run(train_iterator.initializer) + else: + sess.run(ConfigInitializer.get_instance().train_params_config.get_initializer(True)) + else: + sess = tf.compat.v1.Session(config=sess_config(dump_data=False)) + sess.run(tf.compat.v1.global_variables_initializer()) + if not MODIFY_GRAPH_FLAG: + sess.run(train_iterator.initializer) + else: + sess.run(ConfigInitializer.get_instance().train_params_config.get_initializer(True)) + + epoch = 0 + cost_sum = 0 + qps_sum = 0 + best_auc = 0 + iteration_per_loop = 10 + + train_ops = util.set_iteration_per_loop(sess, train_ops, 10) + + # for i in range(1, TRAIN_STEPS): + i = 0 + while True: + i += 1 + logger.info(f"################ training at step {i * iteration_per_loop} ################") + start_time = time.time() + + try: + grad, loss = sess.run([train_ops, train_model.get("loss")]) + lr = sess.run(cfg.learning_rate) + global_step = sess.run(cfg.global_step) + except tf.errors.OutOfRangeError: + logger.info(f"Encounter the end of Sequence for training.") + break + + end_time = time.time() + cost_time = end_time - start_time + qps = (1 / cost_time) * rank_size * cfg.batch_size * iteration_per_loop + cost_sum += cost_time + logger.info(f"step: {i * iteration_per_loop}; training loss: {loss}") + logger.info(f"step: {i * iteration_per_loop}; grad: {grad}") + logger.info(f"step: {i * iteration_per_loop}; lr: {lr}") + logger.info(f"global step: {global_step}") + logger.info(f"step: {i * iteration_per_loop}; current sess cost time: {cost_time:.10f}; current QPS: {qps}") + logger.info(f"training at step:{i * iteration_per_loop}, " + f"table[{sparse_hashtable_wide.table_name}], " + f"table size:{sparse_hashtable_wide.size()}, table capacity:{sparse_hashtable_wide.capacity()}, " + f"table[{sparse_hashtable_deep.table_name}], " + f"table size:{sparse_hashtable_deep.size()}, table capacity:{sparse_hashtable_deep.capacity()}") + + if i % (train_steps // iteration_per_loop) == 0: + if interval is not None: + test_auc, test_mean_log_loss = evaluate_fix(i * iteration_per_loop) + else: + test_auc, test_mean_log_loss = evaluate() + print("Test auc: {}; log_loss: {} ".format(test_auc, test_mean_log_loss)) + best_auc = max(best_auc, test_auc) + logger.info(f"training step: {i * iteration_per_loop}, best auc: {best_auc}") + + sess.close() + + terminate_config_initializer() + logger.info("Demo done!") diff --git a/examples/WideDeep/model/mean_auc.py b/examples/WideDeep/model/mean_auc.py new file mode 100644 index 
0000000000000000000000000000000000000000..ff57df00e575551456883147a5772a3b16dc638f --- /dev/null +++ b/examples/WideDeep/model/mean_auc.py @@ -0,0 +1,40 @@ +# coding=utf-8 +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import os +from glob import glob +import numpy as np + + +def split_auc(log_input): + with open(log_input, 'r') as log: + all_auc = [] + for line in log.readlines(): + if 'Test' in line: + all_auc.append(float(line.split(';')[0].split(':')[-1].strip())) + all_auc_len = len(all_auc) + all_auc_arr = np.array(all_auc)[:all_auc_len - all_auc_len % 8] + test_auc = np.mean(all_auc_arr.reshape(-1, 8), axis=-1) + return test_auc + + +log_path_all = 'latest_*.log' +log_path_list = glob(log_path_all) + +for log_path in log_path_list: + print(os.path.basename(log_path)) + print(split_auc(log_path)) + print('*'*20) \ No newline at end of file diff --git a/examples/WideDeep/model/model.py b/examples/WideDeep/model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..bfe2177ec5ff78e9fb149a6505c766087ea56395 --- /dev/null +++ b/examples/WideDeep/model/model.py @@ -0,0 +1,87 @@ +# coding=utf-8 +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
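A short aside on `mean_auc.py` above: it scrapes the `Test auc: ...; log_loss: ...` lines out of each run log and averages every consecutive group of 8 AUC values (presumably one per card). A quick illustration of the parsing and the reshape, using a made-up log line:

```python
import numpy as np

line = "Test auc: 0.7312; log_loss: 0.4567 "            # hypothetical log line in the expected format
auc = float(line.split(';')[0].split(':')[-1].strip())  # -> 0.7312

all_auc = np.array([0.70, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76, 0.77])  # one AUC per card
print(auc, np.mean(all_auc.reshape(-1, 8), axis=-1))    # 0.7312 [0.735]
```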
+# ============================================================================== + +import time +from easydict import EasyDict as edict + +import tensorflow as tf + + +model_cfg = edict() +model_cfg.loss_mode = "batch" +LOSS_OP_NAME = "loss" +LABEL_OP_NAME = "label" +VAR_LIST = "variable" +PRED_OP_NAME = "pred" + + +class MyModel: + def __init__(self): + self.kernel_init = None + self._loss_fn = None + self.is_training = None + + def build_model(self, model_args): + wide_embedding = model_args.wide_embedding + deep_embedding = model_args.deep_embedding + label = model_args.label + is_training = model_args.is_training + seed = model_args.seed + dropout_rate = model_args.dropout_rate + batch_norm = model_args.batch_norm + + with tf.variable_scope("wide_deep", reuse=tf.AUTO_REUSE): + self._loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True) + self.is_training = is_training + + # wide + batch_size, wide_num, wide_emb_dim = wide_embedding.shape + wide_input = tf.reshape(wide_embedding[:, :, 0], shape=(batch_size, wide_num * 1)) + wide_output = tf.reshape(tf.reduce_sum(wide_input, axis=1), shape=(-1, 1)) + + # deep + batch_size, deep_num, deep_emb_dim = deep_embedding.shape + deep_input = tf.reshape(deep_embedding, shape=(batch_size, deep_num * deep_emb_dim)) + + ## MLP + hidden_units = [256, 128, 64] + net = deep_input + for i, unit in enumerate(hidden_units): + + net = tf.layers.dense(net, units=unit, activation='relu', name=f'hidden_layer_{i}', + kernel_initializer=tf.glorot_uniform_initializer(seed=seed), + bias_initializer=tf.zeros_initializer()) + + if dropout_rate is not None and 0.0 < dropout_rate < 1.0: + net = tf.layers.dropout(net, dropout_rate, training=self.is_training) + if batch_norm: + net = tf.layers.batch_normalization(net, training=self.is_training) + + deep_output = tf.layers.dense(net, units=1, activation=None, name='deep_output', + kernel_initializer=tf.glorot_uniform_initializer(seed=seed), + bias_initializer=tf.zeros_initializer()) + + total_logits = 0.5 * tf.add(wide_output, deep_output, name='total_logits') + loss = self._loss_fn(label, total_logits) + prediction = tf.sigmoid(total_logits) + trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='wide_deep') + return {LOSS_OP_NAME: loss, + PRED_OP_NAME: prediction, + LABEL_OP_NAME: label, + VAR_LIST: trainable_variables} + + +my_model = MyModel() diff --git a/examples/WideDeep/model/op_impl_mode.ini b/examples/WideDeep/model/op_impl_mode.ini new file mode 100644 index 0000000000000000000000000000000000000000..579dea433d6ec01f9ad5596646a0a0fce9ce3a20 --- /dev/null +++ b/examples/WideDeep/model/op_impl_mode.ini @@ -0,0 +1 @@ +ScatterNdAdd=support_out_of_bound_index \ No newline at end of file diff --git a/examples/WideDeep/model/optimizer.py b/examples/WideDeep/model/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..2c7685bba414d6a869c74714797fe807c002713b --- /dev/null +++ b/examples/WideDeep/model/optimizer.py @@ -0,0 +1,35 @@ +# coding=utf-8 +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
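On `model.py` above: the wide tower keeps only component 0 of each wide embedding and sums across features, the deep tower is a 256/128/64 MLP, and the two are merged as `total_logits = 0.5 * (wide_output + deep_output)` before a from-logits binary cross-entropy. A NumPy shape walk-through with hypothetical sizes, just to make the tensor flow concrete:

```python
import numpy as np

batch, wide_num, emb_dim = 2, 26, 8                      # hypothetical sizes
wide_emb = np.random.rand(batch, wide_num, emb_dim).astype(np.float32)
wide_out = wide_emb[:, :, 0].sum(axis=1, keepdims=True)  # component 0 per feature, summed -> (batch, 1)
deep_out = np.zeros((batch, 1), dtype=np.float32)        # stand-in for the MLP tower output
logits = 0.5 * (wide_out + deep_out)                     # mirrors total_logits in model.py
pred = 1.0 / (1.0 + np.exp(-logits))                     # mirrors tf.sigmoid(total_logits)
print(pred.shape)                                        # (2, 1)
```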
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import tensorflow as tf +from delay_loss_scale import DenseLossScaleOptimizer, SparseLossScaleOptimizer +from mx_rec.util.initialize import ConfigInitializer +from mx_rec.optimizers.lazy_adam import create_hash_optimizer +from mx_rec.optimizers.lazy_adam_by_addr import create_hash_optimizer_by_address + + +def get_dense_and_sparse_optimizer(cfg): + dense_optimizer = tf.train.AdamOptimizer(learning_rate=cfg.learning_rate[0]) + use_dynamic_expansion = ConfigInitializer.get_instance().use_dynamic_expansion + if use_dynamic_expansion: + sparse_optimizer = create_hash_optimizer_by_address(learning_rate=cfg.learning_rate[1]) + else: + sparse_optimizer = create_hash_optimizer(learning_rate=cfg.learning_rate[1]) + loss_scale = 1 + sparse_optimizer = SparseLossScaleOptimizer(sparse_optimizer, loss_scale) + dense_optimizer = DenseLossScaleOptimizer(dense_optimizer, loss_scale) + + return dense_optimizer, sparse_optimizer diff --git a/examples/WideDeep/model/run.sh b/examples/WideDeep/model/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..6c1424435bac57a727cfc1980e25e08c6c5ba74c --- /dev/null +++ b/examples/WideDeep/model/run.sh @@ -0,0 +1,99 @@ +#!/bin/bash +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
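On `optimizer.py` above: both optimizers are wrapped in loss-scale optimizers with `loss_scale = 1`, so the scaling is numerically a no-op here and the wrappers mainly keep the training path uniform. The internals of `DenseLossScaleOptimizer`/`SparseLossScaleOptimizer` are not shown in this diff; the following is only a sketch of what static loss scaling generally does, under that assumption:

```python
class StaticLossScale:
    """Sketch of static loss scaling: scale the loss up so small fp16 gradients
    do not underflow, then scale the gradients back down before applying them."""

    def __init__(self, loss_scale):
        self.loss_scale = float(loss_scale)

    def scale_loss(self, loss):
        return loss * self.loss_scale

    def unscale_grads(self, grads):
        return [g / self.loss_scale if g is not None else None for g in grads]
```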
+# ============================================================================== + +cur_path=$(dirname "$(readlink -f "$0")") + +so_path=$1 +mx_rec_package_path=$2 +hccl_cfg_json=$3 +dlrm_criteo_data_path=$4 +ip=$5 # no ranktable时传入该参数 + +interface="lo" +num_server=1 +local_rank_size=8 +num_process=$((num_server * local_rank_size)) +export TRAIN_RANK_SIZE=$num_process + +################# 参数配置 ###################### +export USE_DYNAMIC=0 # 0:静态shape;1:动态shape +export CACHE_MODE="HBM" # HBM;DDR;SSD +export USE_FAAE=0 # 0:关闭准入淘汰;1:开启准入淘汰 +export USE_DYNAMIC_EXPANSION=0 # 0:关闭动态扩容;1: 开启动态扩容 +export USE_MULTI_LOOKUP=0 # 0:一表一查;1:一表多查 +export USE_MODIFY_GRAPH=0 # 0:feature spec模式;1:自动改图模式 +################################################ +echo "CACHE_MODE:${CACHE_MODE}" + +export HCCL_CONNECT_TIMEOUT=1200 +export DLRM_CRITEO_DATA_PATH=${dlrm_criteo_data_path} +export PYTHONPATH=${mx_rec_package_path}:${so_path}:$PYTHONPATH +export LD_PRELOAD=/usr/lib64/libgomp.so.1 +export LD_LIBRARY_PATH=${so_path}:/usr/local/lib:$LD_LIBRARY_PATH +export ASCEND_DEVICE_ID=0 +export RANK_ID_START=0 +export JOB_ID=10086 +export CUSTOMIZED_OPS_LIB_PATH=${so_path}/libcust_ops.so # Todo: please config +export MXREC_LOG_LEVEL="INFO" +export TF_CPP_MIN_LOG_LEVEL=3 +export ASCEND_GLOBAL_LOG_LEVEL=3 +#export USE_FAAE=1 +export ENABLE_FORCE_V2_CONTROL=1 + +export PROFILING_OPTIONS='{"output":"/home/yz/profiling", + "training_trace":"on", + "task_trace":"on", + "aicpu":"on", + "fp_point":"", + "bp_point":"", + "aic_metrics":"PipeUtilization"}' + +RANK_ID_START=0 + +export MXREC_MODE="ASC" +echo "MXREC_MODE is $MXREC_MODE" +export py=main_mxrec.py +echo "py is $py" + +# 区分ranktable和no ranktable +if [ -n "$ip" ]; then + # no ranktable分支 + echo "Current is no ranktable solution." + echo "Input node ip: $ip, please make sure this ip is available." + export CM_CHIEF_IP=$ip # 主节点ip + export CM_CHIEF_PORT=60001 # 主节点监听端口 + export CM_CHIEF_DEVICE=0 # 主节点device id + export CM_WORKER_IP=$ip # 当前节点ip + export CM_WORKER_SIZE=$num_process # 参与集群训练的device数量 + echo "CM_CHIEF_IP=$CM_CHIEF_IP" + echo "CM_CHIEF_PORT=$CM_CHIEF_PORT" + echo "CM_CHIEF_DEVICE=$CM_CHIEF_DEVICE" + echo "CM_WORKER_IP=$CM_WORKER_IP" + echo "CM_WORKER_SIZE=$CM_WORKER_SIZE" +else + # ranktable分支 + echo "Current is ranktable solution, hccl json file:${hccl_cfg_json}" + export RANK_SIZE=$num_process + echo "RANK_SIZE=${RANK_SIZE}, please make sure hccl configuration json file match this parameter" + export RANK_TABLE_FILE=${hccl_cfg_json} +fi + +echo "use horovod to start tasks" +# GLOG_stderrthreshold -2:TRACE -1:DEBUG 0:INFO 1:WARN 2.ERROR, 默认为INFO +mpi_args='-x BIND_INFO="0:12 12:48 60:48" -x GLOG_stderrthreshold=2 -x GLOG_logtostderr=true -bind-to none -x NCCL_SOCKET_IFNAME=docker0 -mca btl_tcp_if_exclude docker0' + +horovodrun --network-interface ${interface} -np ${num_process} --mpi-args "${mpi_args}" --mpi -H localhost:${local_rank_size} \ +python3.7 ${py} 2>&1 | tee temp_${CACHE_MODE}_${num_process}p.log diff --git a/examples/demo/README.md b/examples/demo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..931f8c268e0a2f4f0202b0a59576c590c125d582 --- /dev/null +++ b/examples/demo/README.md @@ -0,0 +1,13 @@ +# demo样例说明 + +## 代码结构 +```shell +. 
+├── little_demo # sess.run模式的demo +├── little_demo_estimator # estimator模式的demo +└── README.md # demo样例说明 +``` + +mxRec提供了一个非常简单的样例模型demo,用于快速体验mxRec。在TensorFlow中,运行模型有sess.run和estimator两种模式。因此,mxRec也提供了两种 +模式下的样例。其中little_demo是sess.run模式的样例;little_demo_estimator是estimator模式的样例。用户可以选择自己需要或者感兴趣的模式进行 +体验,各个模式的样例的说明见对应目录下的README文档。 \ No newline at end of file diff --git a/examples/demo/little_demo/README.md b/examples/demo/little_demo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..dabe105beebe898820c43ad3c11b0c93ee4f337b --- /dev/null +++ b/examples/demo/little_demo/README.md @@ -0,0 +1,56 @@ +# sess.run模式下demo模型运行说明 + +## 代码结构 +```shell +. +├── config.py # 模型配置文件 +├── dataset.py # 生成数据集的脚本 +├── deterministic_loss # 确定性计算loss样例 +├── main.py # 主函数 +├── model.py # demo模型 +├── op_impl_mode.ini # 算子执行模式配置 +├── optimizer.py # 优化器 +├── random_data_generator.py # 数据生成器 +├── README.md # demo模型运行说明 +├── run_deterministic.sh # 运行确定性计算的脚本 +├── run_mode.py # 执行模型train、evaluate和predict的脚本 +└── run.sh # demo运行脚本 +``` + +## 1.准备数据 +demo样例无需从其他地方下载数据集,在demo样例中mxRec会自动生成数据集,详情见dataset.py和random_data_generator.py。 + +## 2.准备运行环境 +运行环境可以参考[mxRec用户指南](https://www.hiascend.com/document/detail/zh/mind-sdk/60rc1/mxRec/mxrecug/mxrecug_0007.html) +“安装部署”章节进行准备。 + +## 3.安装mxRec +mxRec软件包可以通过[mxRec用户指南](https://www.hiascend.com/document/detail/zh/mind-sdk/60rc1/mxRec/mxrecug/mxrecug_0007.html) +“安装部署”>“环境准备”>“获取软件包”章节提供的链接进行下载,选择自己需要的架构(x86或者arm)的mxRec包。下载完成之后,将mxRec包解压,进入解压后的目录(mindxsdk-mxrec) +如下: +```shell +. +├── cust_op +│ └── cust_op_by_addr +├── examples +│ ├── DCNv2 +│ ├── demo +│ └── dlrm +├── tf1_whl +│ └── mx_rec-{version}-py3-none-linux_x86_64.whl # version为版本号 +├── tf2_whl +│ └── mx_rec-{version}-py3-none-linux_x86_64.whl # version为版本号 +└── version.info +``` +其中,tf1_whl和tf2_whl目录下分别是适配tf1和tf2的mxRec软件包,按照自己需要选择其中一个进行安装即可(用pip/pip3 install 软件包这种方式进行安装)。 +确认安装mxRec的目录,比如mxRec安装在 /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec这个目录下。 + +## 4.运行demo模型 +执行完以上步骤之后,接下来就可以运行demo模型,其中run.sh就是运行的脚本,默认是8张卡。其中需要传入ip这个参数,运行命令如: +```shell +bash run.sh main.py {ip} +``` +* ip:ip是运行模型的机器所在的ip。 + +**Tips**:run.sh脚本中有一个参数是mx_rec_package_path,mx_rec_package_path是mxRec的安装目录,比如:/usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec。 +这个参数在脚本是默认的,用户需要根据自己环境中mxRec实际安装的路径进行配置。 \ No newline at end of file diff --git a/examples/demo/little_demo/config.py b/examples/demo/little_demo/config.py index 2cc48216f17b345fc0fd91832afbd28c6409104a..a0912ac511abf447f632aea67f443a9fa189d17e 100644 --- a/examples/demo/little_demo/config.py +++ b/examples/demo/little_demo/config.py @@ -95,7 +95,7 @@ class Config: self.learning_rate = 0.01 -def sess_config(dump_data=False, dump_path="./dump_output", dump_steps="0|1|2"): +def sess_config(dump_data=False, dump_path="./dump_output", dump_steps="0|1|2", use_deterministic=0): session_config = tf.compat.v1.ConfigProto(allow_soft_placement=False, log_device_placement=False) @@ -108,7 +108,11 @@ def sess_config(dump_data=False, dump_path="./dump_output", dump_steps="0|1|2"): custom_op.parameter_map["HCCL_algorithm"].s = tf.compat.as_bytes("level0:pairwise;level1:pairwise") custom_op.parameter_map["enable_data_pre_proc"].b = True custom_op.parameter_map["iterations_per_loop"].i = 1 - custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision") + if use_deterministic: + custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("must_keep_origin_dtype") + custom_op.parameter_map["deterministic"].i = 1 + else: + 
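# Illustrative note, not part of the original diff: "must_keep_origin_dtype"
# together with deterministic=1 trades throughput for reproducible results,
# which the deterministic-loss comparison below depends on; the default branch
# keeps "allow_mix_precision" for speed.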
custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision") custom_op.parameter_map["hcom_parallel"].b = False custom_op.parameter_map["op_precision_mode"].s = tf.compat.as_bytes("op_impl_mode.ini") custom_op.parameter_map["op_execute_timeout"].i = 2000 diff --git a/examples/demo/little_demo/deterministic_loss/loss b/examples/demo/little_demo/deterministic_loss/loss new file mode 100644 index 0000000000000000000000000000000000000000..3bd00f807634a6e5c3c4f890d6a8f28301b0be3b --- /dev/null +++ b/examples/demo/little_demo/deterministic_loss/loss @@ -0,0 +1,200 @@ +0.6931473016738892 +0.6930400133132935 +0.6931400895118713 +0.69315505027771 +0.6931849122047424 +0.6931070685386658 +0.6931337714195251 +0.6931014657020569 +0.6931450963020325 +0.6931362152099609 +0.6930745244026184 +0.6931930184364319 +0.693183958530426 +0.6931136846542358 +0.6932246088981628 +0.69315105676651 +0.6931785941123962 +0.6931335926055908 +0.6931543946266174 +0.6931360960006714 +0.6931753158569336 +0.6931651830673218 +0.6931512951850891 +0.6931533217430115 +0.6931378841400146 +0.6931486129760742 +0.6931435465812683 +0.6931432485580444 +0.6930928230285645 +0.6931749582290649 +0.693172037601471 +0.6931487917900085 +0.6931713819503784 +0.6931683421134949 +0.6931532621383667 +0.6931494474411011 +0.6932084560394287 +0.6930452585220337 +0.6931130886077881 +0.6932073831558228 +0.6931206583976746 +0.6931828856468201 +0.6931034922599792 +0.6931605935096741 +0.6931373476982117 +0.6931723952293396 +0.6931106448173523 +0.6931154131889343 +0.6931938529014587 +0.6932826638221741 +0.6932423114776611 +0.6931906342506409 +0.6931505799293518 +0.6931438446044922 +0.6931610107421875 +0.6931508779525757 +0.6931482553482056 +0.693139910697937 +0.693148136138916 +0.6931435465812683 +0.6930944323539734 +0.693130373954773 +0.6931836009025574 +0.6930789947509766 +0.6932032108306885 +0.693130373954773 +0.6933913230895996 +0.6931992173194885 +0.6931376457214355 +0.6931767463684082 +0.6931583881378174 +0.6931485533714294 +0.693138837814331 +0.6931250095367432 +0.693103015422821 +0.6931023597717285 +0.6932260990142822 +0.6931752562522888 +0.6930729150772095 +0.6929311156272888 +0.693302571773529 +0.6932254433631897 +0.69317626953125 +0.693097710609436 +0.6930376291275024 +0.6931532621383667 +0.6931279301643372 +0.6931777596473694 +0.6931577324867249 +0.6931435465812683 +0.6931730508804321 +0.693141520023346 +0.6931696534156799 +0.6931543350219727 +0.6931476593017578 +0.6931471824645996 +0.6931589245796204 +0.693145751953125 +0.6931431293487549 +0.6931287050247192 +0.6931427717208862 +0.6931363344192505 +0.6931345462799072 +0.6931136250495911 +0.6930984258651733 +0.6931260228157043 +0.6932109594345093 +0.6931638121604919 +0.6931529641151428 +0.6931443214416504 +0.6931478381156921 +0.6931700110435486 +0.69312983751297 +0.6932106614112854 +0.6930972933769226 +0.6931629776954651 +0.6931963562965393 +0.6932249665260315 +0.6932281851768494 +0.6932195425033569 +0.6931582093238831 +0.6931502819061279 +0.693153440952301 +0.6930547952651978 +0.6932091116905212 +0.6930832862854004 +0.69318687915802 +0.693234384059906 +0.6931787133216858 +0.6931472420692444 +0.6931833624839783 +0.6931379437446594 +0.6931558847427368 +0.693196713924408 +0.6931143999099731 +0.693136990070343 +0.6931957602500916 +0.6931578516960144 +0.6931463479995728 +0.6931509375572205 +0.6931226253509521 +0.6931785941123962 +0.6931405663490295 +0.6931736469268799 +0.6931595206260681 +0.6931319236755371 +0.6931323409080505 +0.6931301355361938 +0.6931783556938171 
+0.6931540966033936 +0.6930714249610901 +0.693152904510498 +0.6931881904602051 +0.6931595206260681 +0.6931363940238953 +0.6931393146514893 +0.6931549310684204 +0.6931518316268921 +0.6931600570678711 +0.6931359767913818 +0.693086564540863 +0.6930826306343079 +0.693168044090271 +0.6931942105293274 +0.6932410001754761 +0.693097710609436 +0.693099856376648 +0.69315505027771 +0.693153977394104 +0.6931472420692444 +0.6931328177452087 +0.6931746602058411 +0.6931381821632385 +0.6931582689285278 +0.6933059692382812 +0.6930915117263794 +0.6931243538856506 +0.6934514045715332 +0.6933988928794861 +0.6932798624038696 +0.6931632161140442 +0.6931505799293518 +0.6931473016738892 +0.6931563019752502 +0.6931017637252808 +0.6932226419448853 +0.6932034492492676 +0.6931058764457703 +0.6932246088981628 +0.6930988430976868 +0.6931736469268799 +0.6931524276733398 +0.6931332945823669 +0.6931236386299133 +0.6931801438331604 +0.6931136250495911 +0.6931392550468445 +0.6931288838386536 +0.6931090950965881 +0.6931648254394531 \ No newline at end of file diff --git a/examples/demo/little_demo/deterministic_loss/loss1 b/examples/demo/little_demo/deterministic_loss/loss1 new file mode 100644 index 0000000000000000000000000000000000000000..cfe29fc934de0555ca75afe6b496d22319bdef41 --- /dev/null +++ b/examples/demo/little_demo/deterministic_loss/loss1 @@ -0,0 +1,200 @@ +0.6931475400924683 +0.6930400133132935 +0.693139910697937 +0.6931551098823547 +0.6931850910186768 +0.6931071877479553 +0.6931338310241699 +0.6931014060974121 +0.6931450963020325 +0.69313645362854 +0.6930742263793945 +0.6931931376457214 +0.6931841373443604 +0.6931138038635254 +0.6932246685028076 +0.6931509971618652 +0.6931785941123962 +0.693133533000946 +0.6931544542312622 +0.6931360363960266 +0.6931753158569336 +0.6931651830673218 +0.6931511163711548 +0.6931532621383667 +0.6931378245353699 +0.6931488513946533 +0.6931437253952026 +0.6931431889533997 +0.693092942237854 +0.6931750178337097 +0.693172037601471 +0.6931487917900085 +0.6931712627410889 +0.6931683421134949 +0.6931533813476562 +0.6931492686271667 +0.6932083964347839 +0.6930453181266785 +0.6931129693984985 +0.6932074427604675 +0.6931206583976746 +0.6931827068328857 +0.6931033730506897 +0.6931606531143188 +0.6931372880935669 +0.69317227602005 +0.6931107044219971 +0.6931154727935791 +0.6931938529014587 +0.6932826638221741 +0.6932423710823059 +0.6931905746459961 +0.6931506395339966 +0.6931438446044922 +0.6931609511375427 +0.69315105676651 +0.6931482553482056 +0.6931400895118713 +0.6931483149528503 +0.6931435465812683 +0.6930944919586182 +0.6931304931640625 +0.6931834816932678 +0.6930789947509766 +0.6932030916213989 +0.693130373954773 +0.6933913826942444 +0.6931991577148438 +0.6931377649307251 +0.6931768655776978 +0.6931586861610413 +0.6931484341621399 +0.6931391358375549 +0.6931250691413879 +0.6931028366088867 +0.6931021213531494 +0.6932262182235718 +0.6931752562522888 +0.6930727362632751 +0.6929311156272888 +0.6933025121688843 +0.6932255625724792 +0.6931764483451843 +0.6930979490280151 +0.6930376887321472 +0.6931535005569458 +0.6931277513504028 +0.6931778788566589 +0.6931575536727905 +0.6931436657905579 +0.6931729316711426 +0.6931415796279907 +0.6931697726249695 +0.6931543946266174 +0.6931476593017578 +0.6931473016738892 +0.6931586861610413 +0.6931456923484802 +0.6931430697441101 +0.6931284070014954 +0.693142831325531 +0.6931363940238953 +0.6931345462799072 +0.6931135058403015 +0.6930984258651733 +0.6931260228157043 +0.6932108998298645 +0.6931638717651367 +0.6931529641151428 +0.6931443810462952 
+0.6931477785110474 +0.6931700110435486 +0.6931299567222595 +0.6932107210159302 +0.6930974125862122 +0.6931627988815308 +0.6931964159011841 +0.6932250261306763 +0.6932283043861389 +0.6932194828987122 +0.6931582093238831 +0.6931501626968384 +0.693153440952301 +0.6930548548698425 +0.6932091116905212 +0.6930834650993347 +0.6931867599487305 +0.6932343244552612 +0.6931787133216858 +0.6931471824645996 +0.6931833028793335 +0.6931377649307251 +0.6931559443473816 +0.693196713924408 +0.6931144595146179 +0.6931368708610535 +0.6931958198547363 +0.6931577920913696 +0.6931461691856384 +0.6931511163711548 +0.6931224465370178 +0.693178653717041 +0.6931405663490295 +0.6931737661361694 +0.6931594014167786 +0.6931319236755371 +0.6931324005126953 +0.6931299567222595 +0.6931784152984619 +0.6931542754173279 +0.6930714845657349 +0.693152666091919 +0.6931881308555603 +0.6931596994400024 +0.6931365132331848 +0.6931394338607788 +0.6931548714637756 +0.6931518316268921 +0.6931599974632263 +0.6931360363960266 +0.6930868029594421 +0.6930827498435974 +0.6931679844856262 +0.6931941509246826 +0.6932410001754761 +0.693097710609436 +0.693099856376648 +0.6931549906730652 +0.6931538581848145 +0.6931471824645996 +0.693132758140564 +0.6931745409965515 +0.6931381225585938 +0.6931583881378174 +0.6933057904243469 +0.693091630935669 +0.6931243538856506 +0.6934512853622437 +0.6933985948562622 +0.6932798624038696 +0.6931629180908203 +0.6931505799293518 +0.6931473612785339 +0.6931563019752502 +0.6931016445159912 +0.6932225227355957 +0.6932035088539124 +0.693105936050415 +0.6932247877120972 +0.6930989027023315 +0.6931736469268799 +0.6931525468826294 +0.6931331753730774 +0.6931236982345581 +0.69318026304245 +0.6931138038635254 +0.6931390762329102 +0.6931287050247192 +0.6931091547012329 +0.6931648850440979 \ No newline at end of file diff --git a/examples/demo/little_demo/deterministic_loss/loss2 b/examples/demo/little_demo/deterministic_loss/loss2 new file mode 100644 index 0000000000000000000000000000000000000000..cfe29fc934de0555ca75afe6b496d22319bdef41 --- /dev/null +++ b/examples/demo/little_demo/deterministic_loss/loss2 @@ -0,0 +1,200 @@ +0.6931475400924683 +0.6930400133132935 +0.693139910697937 +0.6931551098823547 +0.6931850910186768 +0.6931071877479553 +0.6931338310241699 +0.6931014060974121 +0.6931450963020325 +0.69313645362854 +0.6930742263793945 +0.6931931376457214 +0.6931841373443604 +0.6931138038635254 +0.6932246685028076 +0.6931509971618652 +0.6931785941123962 +0.693133533000946 +0.6931544542312622 +0.6931360363960266 +0.6931753158569336 +0.6931651830673218 +0.6931511163711548 +0.6931532621383667 +0.6931378245353699 +0.6931488513946533 +0.6931437253952026 +0.6931431889533997 +0.693092942237854 +0.6931750178337097 +0.693172037601471 +0.6931487917900085 +0.6931712627410889 +0.6931683421134949 +0.6931533813476562 +0.6931492686271667 +0.6932083964347839 +0.6930453181266785 +0.6931129693984985 +0.6932074427604675 +0.6931206583976746 +0.6931827068328857 +0.6931033730506897 +0.6931606531143188 +0.6931372880935669 +0.69317227602005 +0.6931107044219971 +0.6931154727935791 +0.6931938529014587 +0.6932826638221741 +0.6932423710823059 +0.6931905746459961 +0.6931506395339966 +0.6931438446044922 +0.6931609511375427 +0.69315105676651 +0.6931482553482056 +0.6931400895118713 +0.6931483149528503 +0.6931435465812683 +0.6930944919586182 +0.6931304931640625 +0.6931834816932678 +0.6930789947509766 +0.6932030916213989 +0.693130373954773 +0.6933913826942444 +0.6931991577148438 +0.6931377649307251 +0.6931768655776978 +0.6931586861610413 
+0.6931484341621399 +0.6931391358375549 +0.6931250691413879 +0.6931028366088867 +0.6931021213531494 +0.6932262182235718 +0.6931752562522888 +0.6930727362632751 +0.6929311156272888 +0.6933025121688843 +0.6932255625724792 +0.6931764483451843 +0.6930979490280151 +0.6930376887321472 +0.6931535005569458 +0.6931277513504028 +0.6931778788566589 +0.6931575536727905 +0.6931436657905579 +0.6931729316711426 +0.6931415796279907 +0.6931697726249695 +0.6931543946266174 +0.6931476593017578 +0.6931473016738892 +0.6931586861610413 +0.6931456923484802 +0.6931430697441101 +0.6931284070014954 +0.693142831325531 +0.6931363940238953 +0.6931345462799072 +0.6931135058403015 +0.6930984258651733 +0.6931260228157043 +0.6932108998298645 +0.6931638717651367 +0.6931529641151428 +0.6931443810462952 +0.6931477785110474 +0.6931700110435486 +0.6931299567222595 +0.6932107210159302 +0.6930974125862122 +0.6931627988815308 +0.6931964159011841 +0.6932250261306763 +0.6932283043861389 +0.6932194828987122 +0.6931582093238831 +0.6931501626968384 +0.693153440952301 +0.6930548548698425 +0.6932091116905212 +0.6930834650993347 +0.6931867599487305 +0.6932343244552612 +0.6931787133216858 +0.6931471824645996 +0.6931833028793335 +0.6931377649307251 +0.6931559443473816 +0.693196713924408 +0.6931144595146179 +0.6931368708610535 +0.6931958198547363 +0.6931577920913696 +0.6931461691856384 +0.6931511163711548 +0.6931224465370178 +0.693178653717041 +0.6931405663490295 +0.6931737661361694 +0.6931594014167786 +0.6931319236755371 +0.6931324005126953 +0.6931299567222595 +0.6931784152984619 +0.6931542754173279 +0.6930714845657349 +0.693152666091919 +0.6931881308555603 +0.6931596994400024 +0.6931365132331848 +0.6931394338607788 +0.6931548714637756 +0.6931518316268921 +0.6931599974632263 +0.6931360363960266 +0.6930868029594421 +0.6930827498435974 +0.6931679844856262 +0.6931941509246826 +0.6932410001754761 +0.693097710609436 +0.693099856376648 +0.6931549906730652 +0.6931538581848145 +0.6931471824645996 +0.693132758140564 +0.6931745409965515 +0.6931381225585938 +0.6931583881378174 +0.6933057904243469 +0.693091630935669 +0.6931243538856506 +0.6934512853622437 +0.6933985948562622 +0.6932798624038696 +0.6931629180908203 +0.6931505799293518 +0.6931473612785339 +0.6931563019752502 +0.6931016445159912 +0.6932225227355957 +0.6932035088539124 +0.693105936050415 +0.6932247877120972 +0.6930989027023315 +0.6931736469268799 +0.6931525468826294 +0.6931331753730774 +0.6931236982345581 +0.69318026304245 +0.6931138038635254 +0.6931390762329102 +0.6931287050247192 +0.6931091547012329 +0.6931648850440979 \ No newline at end of file diff --git a/examples/demo/little_demo/deterministic_loss/loss3 b/examples/demo/little_demo/deterministic_loss/loss3 new file mode 100644 index 0000000000000000000000000000000000000000..a38ce81cd493edd9017844b477d93cb0c3004ecd --- /dev/null +++ b/examples/demo/little_demo/deterministic_loss/loss3 @@ -0,0 +1,200 @@ +0.6931473016738892 +0.6930400729179382 +0.6931402087211609 +0.69315505027771 +0.6931849122047424 +0.6931070685386658 +0.6931337118148804 +0.6931014060974121 +0.693144679069519 +0.6931362748146057 +0.6930745840072632 +0.6931930780410767 +0.6931840777397156 +0.6931135654449463 +0.6932245492935181 +0.6931509375572205 +0.6931784152984619 +0.6931337714195251 +0.693154513835907 +0.6931360363960266 +0.6931752562522888 +0.6931653022766113 +0.6931512355804443 +0.6931530833244324 +0.6931378841400146 +0.6931486129760742 +0.6931437253952026 +0.6931434273719788 +0.6930927634239197 +0.6931749582290649 +0.6931719779968262 +0.6931490302085876 
+0.6931714415550232 +0.6931683421134949 +0.6931533217430115 +0.6931492686271667 +0.6932084560394287 +0.6930454969406128 +0.6931130290031433 +0.6932073831558228 +0.6931207776069641 +0.6931827664375305 +0.693103551864624 +0.6931607127189636 +0.6931374669075012 +0.69317227602005 +0.6931108832359314 +0.6931152939796448 +0.6931939125061035 +0.6932826638221741 +0.6932422518730164 +0.6931905746459961 +0.693150520324707 +0.6931438446044922 +0.693160891532898 +0.6931508779525757 +0.693148136138916 +0.6931400299072266 +0.6931481957435608 +0.6931434869766235 +0.6930946111679077 +0.6931304335594177 +0.693183422088623 +0.6930789947509766 +0.6932030320167542 +0.6931302547454834 +0.6933913826942444 +0.6931991577148438 +0.6931378841400146 +0.6931770443916321 +0.6931586265563965 +0.6931484937667847 +0.6931388974189758 +0.6931250095367432 +0.693103015422821 +0.6931023001670837 +0.6932262182235718 +0.6931753754615784 +0.6930726170539856 +0.6929312944412231 +0.693302571773529 +0.6932252645492554 +0.6931763291358948 +0.693097710609436 +0.6930376291275024 +0.6931533217430115 +0.6931278705596924 +0.6931778788566589 +0.6931576728820801 +0.6931436657905579 +0.6931729912757874 +0.6931415796279907 +0.6931697130203247 +0.6931542754173279 +0.6931475400924683 +0.6931473016738892 +0.6931589245796204 +0.6931456923484802 +0.6931431293487549 +0.6931285858154297 +0.693142831325531 +0.6931363344192505 +0.6931344270706177 +0.6931136250495911 +0.6930983066558838 +0.6931259632110596 +0.693211019039154 +0.6931636929512024 +0.6931530237197876 +0.6931443214416504 +0.6931476593017578 +0.6931700706481934 +0.69312983751297 +0.6932106614112854 +0.6930974125862122 +0.6931630373001099 +0.6931962370872498 +0.6932251453399658 +0.6932281851768494 +0.6932194828987122 +0.6931582093238831 +0.6931502819061279 +0.693153440952301 +0.6930545568466187 +0.693209171295166 +0.6930832862854004 +0.6931869387626648 +0.6932346224784851 +0.693178653717041 +0.6931472420692444 +0.6931833624839783 +0.6931377649307251 +0.6931559443473816 +0.6931968331336975 +0.6931143999099731 +0.6931371092796326 +0.6931959390640259 +0.6931577324867249 +0.6931463479995728 +0.6931511759757996 +0.6931225061416626 +0.6931787133216858 +0.6931406259536743 +0.6931735873222351 +0.6931595206260681 +0.6931317448616028 +0.6931322813034058 +0.6931302547454834 +0.6931782960891724 +0.6931543946266174 +0.6930716037750244 +0.6931528449058533 +0.6931881904602051 +0.6931595802307129 +0.6931363940238953 +0.6931394934654236 +0.6931549906730652 +0.6931518912315369 +0.6931601762771606 +0.6931359767913818 +0.6930863261222839 +0.6930826902389526 +0.693168044090271 +0.6931941509246826 +0.6932411193847656 +0.6930977702140808 +0.6930997967720032 +0.6931549906730652 +0.6931539177894592 +0.6931472420692444 +0.6931326985359192 +0.6931745409965515 +0.6931379437446594 +0.6931582689285278 +0.6933060884475708 +0.693091630935669 +0.6931243538856506 +0.6934512853622437 +0.693398654460907 +0.6932798624038696 +0.6931633353233337 +0.6931505799293518 +0.6931473612785339 +0.6931564211845398 +0.693101704120636 +0.6932228207588196 +0.6932032704353333 +0.6931060552597046 +0.6932246685028076 +0.6930988430976868 +0.6931737065315247 +0.6931525468826294 +0.6931332945823669 +0.6931236386299133 +0.6931802034378052 +0.6931136846542358 +0.6931393146514893 +0.6931288242340088 +0.6931090950965881 +0.6931647658348083 \ No newline at end of file diff --git a/examples/demo/little_demo/main.py b/examples/demo/little_demo/main.py index 8813de445435390368364c99c181c335a91618ed..cfaecbde4493b66795b3d0fcf9cc293c309ae34b 100644 --- 
a/examples/demo/little_demo/main.py +++ b/examples/demo/little_demo/main.py @@ -21,9 +21,10 @@ import shutil import warnings from glob import glob +import numpy as np import tensorflow as tf -from mx_rec.constants.constants import ASCEND_TIMESTAMP +from mx_rec.constants.constants import ASCEND_TIMESTAMP, CacheModeEnum from mx_rec.core.asc.feature_spec import FeatureSpec from mx_rec.core.asc.helper import get_asc_insert_func from mx_rec.core.asc.manager import start_asc_pipeline @@ -45,12 +46,6 @@ tf.compat.v1.disable_eager_execution() _SSD_SAVE_PATH = ["ssd_data"] # user should make sure directory exist and clean before training -class CacheModeEnum(enum.Enum): - HBM = "HBM" - DDR = "DDR" - SSD = "SSD" - - def make_batch_and_iterator(is_training, feature_spec_list=None, use_timestamp=False, dump_graph=False, batch_number=100): dataset = generate_dataset(cfg, use_timestamp=use_timestamp, batch_number=batch_number) @@ -147,34 +142,45 @@ def create_feature_spec_list(use_timestamp=False): return feature_spec_list -def clear_saved_model(): +def _del_related_dir(del_path: str) -> None: + if not os.path.isabs(del_path): + del_path = os.path.join(os.getcwd(), del_path) + dirs = glob(del_path) + for sub_dir in dirs: + shutil.rmtree(sub_dir, ignore_errors=True) + logger.info(f"delete dir:{sub_dir}") + + +def _clear_saved_model() -> None: + _del_related_dir("/root/ascend/log/*") + _del_related_dir("kernel*") + _del_related_dir("export_graph") + mode = UseMode.mapping(os.getenv("USE_MODE")) - if mode == UseMode.TRAIN: - logger.info("current mode is train, will delete previous saved model data if exist.") - save_model_path = os.path.join(os.getcwd(), "saved-model") - shutil.rmtree(save_model_path, ignore_errors=True) - if not (os.getenv("CACHE_MODE", "") == CacheModeEnum.SSD.value and mode == UseMode.TRAIN): + if mode != UseMode.TRAIN: return + logger.info("current mode is train, will delete previous saved model data if exist.") + _del_related_dir("saved-model") - # ssd not allow overwrite file, should clear it before training - logger.info("current cache mode is SSD, will delete previous saved ssd data if exist.") - for part_path in _SSD_SAVE_PATH: - if "/" not in part_path and "\\" not in part_path: - part_path = os.path.join(os.getcwd(), part_path) - shutil.rmtree(part_path, ignore_errors=True) - try: - os.mkdir(part_path) - except OSError: - logger.warning("ssd path has exist") # 多进程并行,忽略异常 + if not (os.getenv("CACHE_MODE", "") == CacheModeEnum.SSD.value): + return + logger.info("current cache mode is SSD, and file overwrite is not allowed in SSD mode, deleting exist directory" + " then create empty directory for this use case.") + for sub_path in _SSD_SAVE_PATH: + _del_related_dir(sub_path) + os.makedirs(sub_path, mode=0o550, exist_ok=True) + logger.info(f"Create dir:{sub_path}") if __name__ == "__main__": tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) warnings.filterwarnings("ignore") + _clear_saved_model() use_mode = UseMode.mapping(os.getenv("USE_MODE")) # 最大数据集生成数量 - MAX_DATASET_GENERATE = 200 + MAX_DATASET_GENERATE_TRAIN = 200 + MAX_DATASET_GENERATE_EVAL = 10 # 最大训练的步数 MAX_TRAIN_STEPS = 200 # 训练多少步切换为评估 @@ -187,21 +193,25 @@ if __name__ == "__main__": # get init configuration try: use_dynamic = bool(int(os.getenv("USE_DYNAMIC", 0))) - use_hot = bool(int(os.getenv("USE_HOT", 0))) use_dynamic_expansion = bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0))) use_multi_lookup = bool(int(os.getenv("USE_MULTI_LOOKUP", 1))) MODIFY_GRAPH_FLAG = bool(int(os.getenv("USE_MODIFY_GRAPH", 0))) 
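# Illustrative note, not part of the original diff: bool(int(os.getenv(name, 0)))
# only accepts integer strings, e.g. bool(int("1")) -> True while
# bool(int("true")) raises ValueError, hence the handler below that re-raises
# with a message listing every supported flag.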
USE_TIMESTAMP = bool(int(os.getenv("USE_TIMESTAMP", 0)))
         USE_ONE_SHOT = bool(int(os.getenv("USE_ONE_SHOT", 0)))
+        USE_DETERMINISTIC = bool(int(os.getenv("USE_DETERMINISTIC", 0)))
     except ValueError as err:
-        raise ValueError(f"please correctly config USE_MPI or USE_DYNAMIC or USE_HOT or USE_DYNAMIC_EXPANSION or "
-                         f"USE_MULTI_LOOKUP or USE_MODIFY_GRAPH or USE_TIMESTAMP or USE_ONE_SHOT "
-                         f"only 0 or 1 is supported.") from err
+        raise ValueError("please correctly config USE_MPI or USE_DYNAMIC or USE_DYNAMIC_EXPANSION or "
+                         "USE_MULTI_LOOKUP or USE_MODIFY_GRAPH or USE_TIMESTAMP or USE_ONE_SHOT or USE_DETERMINISTIC; "
+                         "only 0 or 1 is supported.") from err
     try:
         MULTI_LOOKUP_TIMES = int(os.getenv("MULTI_LOOKUP_TIMES", 2))
     except ValueError as err:
-        raise ValueError(f"please correctly config MULTI_LOOKUP_TIMES only int is supported.") from err
+        raise ValueError("please correctly config MULTI_LOOKUP_TIMES; only int is supported.") from err
+
+    if USE_DETERMINISTIC:
+        np.random.seed(128)
+        tf.random.set_random_seed(128)

     if_load = False
     save_path = "./saved-model"
@@ -212,13 +222,13 @@ if __name__ == "__main__":
         if len(model_file) == 0:
             raise ValueError(f"get USE_MODE:{use_mode}, but no model file exist at:{load_path_pattern}")
         if_load = True
-
+    # nbatch function needs to be used together with the prefetch and host_vocabulary_size != 0
     init(train_steps=TRAIN_STEPS,
          eval_steps=EVAL_STEPS,
          save_steps=SAVING_INTERVAL,
+         max_steps=MAX_TRAIN_STEPS,
          use_dynamic=use_dynamic,
-         use_hot=use_hot,
          use_dynamic_expansion=use_dynamic_expansion,
          if_load=if_load)

@@ -242,18 +252,17 @@ if __name__ == "__main__":
     eval_feature_spec_list = create_feature_spec_list(use_timestamp=USE_TIMESTAMP)

     optimizer_list = [create_dense_and_sparse_optimizer(cfg)]
-    sparse_optimizer_list = [sparse_optimizer for dense_optimizer, sparse_optimizer in optimizer_list]

     # 如需验证DDR模式,请按照key数量、batch unique数量合理设置device与host表大小。
     # 验证DDR的配置参考:建议跑dynamic避免调参。数据集key总量大于device表,小于device+host;一个batch的unique key数量小于device表。
     # 验证SSD的配置参考:建议跑dynamic避免调参。数据集key总量大于device+host;一个batch的unique key数量小于device表。
     hbm_test_cfg = {"device_vocabulary_size": cfg.user_vocab_size, "host_vocabulary_size": 0}
-    ddr_test_cfg = {"device_vocabulary_size": int(cfg.user_vocab_size * 0.2),
-                    "host_vocabulary_size": int(cfg.user_vocab_size * 0.8)}
+    ddr_test_cfg = {"device_vocabulary_size": int(cfg.user_vocab_size * 0.4),
+                    "host_vocabulary_size": int(cfg.user_vocab_size * 1.0)}
     ssd_test_cfg = {
-        "device_vocabulary_size": int(cfg.user_vocab_size * 0.1),
-        "host_vocabulary_size": int(cfg.user_vocab_size * 0.1),
-        "ssd_vocabulary_size": int(cfg.user_vocab_size * 0.8), "ssd_data_path": _SSD_SAVE_PATH
+        "device_vocabulary_size": int(cfg.user_vocab_size * 0.4),
+        "host_vocabulary_size": int(cfg.user_vocab_size * 0.8),
+        "ssd_vocabulary_size": int(cfg.user_vocab_size * 1.8), "ssd_data_path": _SSD_SAVE_PATH
     }
     cache_mode_dict = {CacheModeEnum.HBM.value: hbm_test_cfg,
                        CacheModeEnum.DDR.value: ddr_test_cfg,
                        CacheModeEnum.SSD.value: ssd_test_cfg}
@@ -263,20 +272,19 @@ if __name__ == "__main__":
         raise ValueError(f"cache mode must in {list(cache_mode_dict.keys())}, get:{cache_mode}")
     if cache_mode in ["DDR", "SSD"] and not use_dynamic:
         logger.warning("when cache_mode in [DDR, SSD], suggest use_dynamic=true to avoid tuning size parameter")
-
+    emb_initializer = tf.compat.v1.constant_initializer(0) if USE_DETERMINISTIC \
+        else tf.compat.v1.truncated_normal_initializer()
     user_hashtable = create_table(key_dtype=tf.int64,
                                   dim=tf.TensorShape([cfg.user_hashtable_dim]),
                                   name='user_table',
-                                  emb_initializer=tf.compat.v1.truncated_normal_initializer(),
-                                  optimizer_list=sparse_optimizer_list,
+                                  emb_initializer=emb_initializer,
                                   all2all_gradients_op="sum_gradients_and_div_by_ranksize",
                                   **cache_mode_dict[cache_mode])

     item_hashtable = create_table(key_dtype=tf.int64,
                                   dim=tf.TensorShape([cfg.item_hashtable_dim]),
                                   name='item_table',
-                                  emb_initializer=tf.compat.v1.truncated_normal_initializer(),
-                                  optimizer_list=sparse_optimizer_list,
+                                  emb_initializer=emb_initializer,
                                   **cache_mode_dict[cache_mode])

     # 在predict的场景下,train model不需要被执行
@@ -285,17 +293,20 @@ if __name__ == "__main__":
     train_batch = None
     table_list = [user_hashtable, item_hashtable]
     if use_mode in [UseMode.TRAIN, UseMode.LOAD_AND_TRAIN]:
-        train_iterator, train_model, train_batch = build_graph(table_list, is_train=True,
-                                                               feature_spec_list=train_feature_spec_list,
-                                                               config_dict=ACCESS_AND_EVICT,
-                                                               batch_number=MAX_DATASET_GENERATE * get_rank_size())
+        train_iterator, train_model, train_batch = build_graph(
+            table_list, is_train=True,
+            feature_spec_list=train_feature_spec_list,
+            config_dict=ACCESS_AND_EVICT,
+            batch_number=MAX_DATASET_GENERATE_TRAIN * get_rank_size()
+        )
     eval_iterator, eval_model, eval_batch = build_graph(table_list, is_train=False,
                                                         feature_spec_list=eval_feature_spec_list,
                                                         config_dict=ACCESS_AND_EVICT,
-                                                        batch_number=MAX_DATASET_GENERATE * get_rank_size())
+                                                        batch_number=MAX_DATASET_GENERATE_EVAL * get_rank_size())

     dense_variables, sparse_variables = get_dense_and_sparse_variable()
-    params = {"train_batch": train_batch, "eval_batch": eval_batch, "use_one_shot": USE_ONE_SHOT}
+    params = {"train_batch": train_batch, "eval_batch": eval_batch, "use_one_shot": USE_ONE_SHOT,
+              "use_deterministic": USE_DETERMINISTIC}
     run_mode = RunMode(
         MODIFY_GRAPH_FLAG, USE_TIMESTAMP, table_list, optimizer_list, train_model, eval_model,
         train_iterator, eval_iterator, MAX_TRAIN_STEPS, EVAL_STEPS, params
diff --git a/examples/demo/little_demo/run.sh b/examples/demo/little_demo/run.sh
index ab74adb2b79333e45adfaad1228d3d2edb2b0d16..5c5d9d1d730bafdcb68367e86ad3879511db0c5e 100644
--- a/examples/demo/little_demo/run.sh
+++ b/examples/demo/little_demo/run.sh
@@ -15,26 +15,12 @@
 # ==============================================================================

 kill -9 `ps -ef | grep python | grep -v grep | awk '{print $2}'` > /dev/null 2>&1
-rm -rf /root/ascend/log/*
-rm -rf ./kernel*
-rm -rf ./export_graph/*

 # 支持[train, load_and_train, predict]
-export USE_MODE="train"
-if [ $USE_MODE = "train" ]; then
-    echo "train mode: saved-model will be deleted"
-    rm -rf ./saved-model
-fi
+export USE_MODE="train"  # if train mode, will remove dir ./saved-model

 # cache mode support: HBM, DDR, SSD
 export CACHE_MODE="HBM"
-if [ $CACHE_MODE = "SSD" ] && [ $USE_MODE = "train" ]; then
-    echo "SSD train mode not allow file exist in directory when training a model from stratch in case overwrite,
-    deleting directory ssd_data then create for this use case"
-    rm -rf ssd_data
-    mkdir ssd_data
-fi
-

 # 获取输入参数:py、ip
 if [ $# -ge 1 ]; then
@@ -100,15 +86,13 @@
 export TF_CPP_MIN_LOG_LEVEL=3  # tensorflow日志级别,3对应FATAL
 # 设置应用类日志的全局日志级别及各模块日志级别,具体请参考昇腾官网CANN文档
 export ASCEND_GLOBAL_LOG_LEVEL=3  # “设置日志级别”章节0:debug, 1:info, 2:warning, 3:error, 4:NULL
 export MXREC_MODE="ASC"
-export USE_MPI=1

 ################# 参数配置 ######################
 export USE_DYNAMIC=1  # 0:静态shape;1:动态shape
-export USE_HOT=0  # 0:关闭hot emb;1: 开启hot emb
 export USE_DYNAMIC_EXPANSION=0  # 0:关闭动态扩容;1: 开启动态扩容
 export USE_MULTI_LOOKUP=1  # 0:一表一查;1:一表多查
 export MULTI_LOOKUP_TIMES=2  # 一表多查次数:默认2,上限127(因为一表已经有一查);仅当export USE_MULTI_LOOKUP=1时生效
-export USE_MODIFY_GRAPH=0  # 0:feature spec模式;1:自动改图模式
+export USE_MODIFY_GRAPH=1  # 0: feature spec mode; 1: automatic graph-modification mode
 export USE_TIMESTAMP=0  # 0:关闭特征准入淘汰;1:开启特征准入淘汰
 export USE_ONE_SHOT=0  # 0:MakeIterator;1:OneShotIterator
 export UpdateEmb_V2=1  # 0: UpdateEmb同步更新;1:UpdateEmb_V2异步更新
@@ -162,7 +146,6 @@ else
     echo "CM_CHIEF_DEVICE=$CM_CHIEF_DEVICE"
     echo "CM_WORKER_IP=$CM_WORKER_IP"
     echo "CM_WORKER_SIZE=$CM_WORKER_SIZE"
-    echo "ASCEND_VISIBLE_DEVICES=$ASCEND_VISIBLE_DEVICES"
     #########################################################
 else
     echo "ip: $ip not available!"  # 使用ranktable方案
@@ -177,5 +160,5 @@ fi
 echo "use horovod to start tasks"
 DATE=$(date +%Y-%m-%d-%H-%M-%S)
 horovodrun --network-interface ${interface} -np ${num_process} --mpi-args "${mpi_args}" --mpi -H localhost:${local_rank_size} \
-python3.7 ${py} 2>&1 | tee "temp_${local_rank_size}p_${KEY_PROCESS_THREAD_NUM}t_${USE_MODE}_${CACHE_MODE}_${DATE}.log"
+python3.7 ${py} 2>&1 | tee "temp_${num_process}p_${KEY_PROCESS_THREAD_NUM}t_${USE_MODE}_${CACHE_MODE}_${DATE}.log"
diff --git a/examples/demo/little_demo/run_deterministic.sh b/examples/demo/little_demo/run_deterministic.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fbb4342da388287751e8d7e79d95d03a0f91f5bd
--- /dev/null
+++ b/examples/demo/little_demo/run_deterministic.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+export USE_DETERMINISTIC=1
+
+sh run.sh main.py | tee log
+
+grep -rn "loss" log | grep "1,0" | awk '{print $NF}'> loss
+
+rm -f log
+
+soc_name=`python3 -c 'import acl;print(acl.get_soc_name())'`
+echo "soc_name: $soc_name"
+
+loss_file=deterministic_loss/loss${soc_name:10:1}
+
+if [ ! -e $loss_file ];then
+    echo "$loss_file file does not exist"
+    rm -f loss
+    exit
+fi
+
+
+diff $loss_file loss
+
+if [ $? -eq 0 ]; then
+    echo "deterministic loss check passed"
+else
+    echo "deterministic loss check failed"
+fi
+
+rm -f loss
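The deterministic check above is byte-exact: with fixed seeds and a constant initializer, a re-run must reproduce the golden per-step losses exactly, so the script compares files with `diff` rather than within a tolerance. A minimal Python equivalent of that check (the file names are the hypothetical ones produced by the script above, not part of any API):

```python
# Sketch: byte-exact comparison of the per-step loss log against a golden file,
# mirroring the `diff $loss_file loss` step in run_deterministic.sh.
from pathlib import Path


def check_deterministic_loss(golden_file: str, current_file: str) -> bool:
    golden = Path(golden_file).read_text().splitlines()
    current = Path(current_file).read_text().splitlines()
    if golden != current:
        print("deterministic loss check failed")
        return False
    print("deterministic loss check passed")
    return True


if __name__ == "__main__":
    # "deterministic_loss/loss9" stands in for deterministic_loss/loss${soc_name:10:1}
    check_deterministic_loss("deterministic_loss/loss9", "loss")
```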
diff --git a/examples/demo/little_demo/run_mode.py b/examples/demo/little_demo/run_mode.py
index 0f7a8cc4827d1052105af083605f61caa40c564d..1a15fcc61b7d82b17847d36c1076234d2b7c8074 100644
--- a/examples/demo/little_demo/run_mode.py
+++ b/examples/demo/little_demo/run_mode.py
@@ -16,6 +16,7 @@
 # ==============================================================================

 import os
+import sys
 from typing import List

 import tensorflow as tf
@@ -44,7 +45,9 @@ class RunMode:
                  eval_model, train_iterator, eval_iterator, max_train_steps: int, infer_steps: int, params: dict):
         self.is_modify_graph = is_modify_graph
         self.is_faae = is_faae
-        self.session = tf.compat.v1.Session(config=sess_config(dump_data=False))
+        self.use_deterministic = params.get("use_deterministic")
+        self.session = tf.compat.v1.Session(
+            config=sess_config(dump_data=False, use_deterministic=self.use_deterministic))
         self.train_model = train_model
         self.train_iterator = train_iterator
         self.eval_model = eval_model
@@ -70,12 +73,14 @@ class RunMode:
             channel_id = ConfigInitializer.get_instance().train_params_config.get_training_mode_channel_id(False)
             import_host_pipeline_ops().clear_channel(channel_id)

+        if self.infer_steps == -1:
+            self.infer_steps = sys.maxsize  # consume all remaining data
         for i in range(1, self.infer_steps + 1):
             logger.info("############### infer at step %d ################", i)
             try:
                 self.session.run(self.eval_model.loss_list)
             except tf.errors.OutOfRangeError:
-                logger.info(f"Encounter the end of Sequence for eval.")
+                logger.info("Encounter the end of Sequence for eval.")
                 break

     def set_train_ops(self):
@@ -95,11 +100,11 @@ class RunMode:
             self.train_ops.append(dense_optimizer.apply_gradients(avg_grads))

         if bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0))):
-            from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS
+            from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET

             train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB)
-            train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS)
+            train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET)

             # do sparse optimization by addr
             local_grads = tf.gradients(loss, train_emb_list)  # local_embedding
@@ -124,36 +129,40 @@ class RunMode:
                 self.session.run(initializer)
         else:
             logger.debug(f"use one shot iterator and modify graph is `{self.is_modify_graph}`.")

         self.saver = tf.compat.v1.train.Saver()
+        latest_ckpt_step = 0
         start_step = 1
         if if_load:
-            latest_step = get_load_step(model_file)
-            start_step = latest_step + 1
-            self.saver.restore(self.session, f"./saved-model/model-{latest_step}")
+            latest_ckpt_step = get_load_step(model_file)
+            start_step = latest_ckpt_step + 1
+            self.saver.restore(self.session, f"./saved-model/model-{latest_ckpt_step}")
         else:
             self.session.run(tf.compat.v1.global_variables_initializer())

+        if self.max_train_steps == -1:
+            self.max_train_steps = sys.maxsize  # consume all remaining data
         for i in range(start_step, start_step + self.max_train_steps):
             logger.info("################ training at step %d ################", i)
             try:
-                self.session.run([self.train_ops, self.train_model.loss_list])
+                _, loss = self.session.run([self.train_ops, self.train_model.loss_list])
+                if self.use_deterministic:
+                    logger.info(f"train_loss: {loss[0]}")
             except tf.errors.OutOfRangeError:
-                logger.info(f"Encounter the end of Sequence for training.")
+                logger.info("Encounter the end of Sequence for training.")
                 break
             else:
                 for t in self.table_list:
                     logger.info(f"training at step:{i}, table[{t.table_name}], table size:{t.size()}, "
                                 f"table capacity:{t.capacity()}")
-            if i % train_interval == 0:
+            if train_interval != -1 and (i - latest_ckpt_step) % train_interval == 0:
                 self.evaluate()
-            if i % saving_interval == 0:
+            if saving_interval != -1 and (i - latest_ckpt_step) % saving_interval == 0:
                 self.saver.save(self.session, f"./saved-model/model", global_step=i)
-            if self.is_faae and i == train_interval // 2:
+            if train_interval != -1 and self.is_faae and i == train_interval // 2:
                 logger.info("############### set_threshold at step:%d ################", i)
                 self.change_threshold()
@@ -170,14 +179,14 @@ class RunMode:
         self.epoch += 1

     def predict(self, model_file: List[str]):
-        logger.info(f"############### start predict ################")
+        logger.info("############### start predict ################")
         # get the latest model
         latest_step = get_load_step(model_file)
         self.saver = tf.compat.v1.train.Saver()
         self.saver.restore(self.session, f"./saved-model/model-{latest_step}")
         self._infer()
-        logger.info(f"############### predict end ################")
+        logger.info("############### predict end ################")

     def change_threshold(self):
         thres_tensor = tf.constant(60, dtype=tf.int32)
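The step accounting introduced in run_mode.py treats `-1` as a sentinel (run until the dataset is exhausted, or disable a periodic action) and counts eval/save intervals relative to the restored checkpoint step. A self-contained sketch of just that logic, with illustrative step values:

```python
import sys

# Sketch of the sentinel handling in RunMode.train: interval == -1 disables
# periodic eval/save, and intervals are measured from the restored checkpoint
# step rather than from step 0, so a resumed run keeps the same cadence.
def should_trigger(step: int, latest_ckpt_step: int, interval: int) -> bool:
    return interval != -1 and (step - latest_ckpt_step) % interval == 0


max_train_steps = -1
if max_train_steps == -1:
    max_train_steps = sys.maxsize  # loop until tf.errors.OutOfRangeError ends training

assert should_trigger(step=105, latest_ckpt_step=100, interval=5)
assert not should_trigger(step=105, latest_ckpt_step=101, interval=5)
assert not should_trigger(step=105, latest_ckpt_step=100, interval=-1)
```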
diff --git a/examples/demo/little_demo_estimator/README.md b/examples/demo/little_demo_estimator/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..aca25a34f4bbe3b06ba2f4f4683ec6784160fb72
--- /dev/null
+++ b/examples/demo/little_demo_estimator/README.md
@@ -0,0 +1,57 @@
+# Running the demo model in estimator mode
+
+## Code structure
+```shell
+.
+├── config.py                 # model configuration
+├── dataset.py                # dataset generation script
+├── main.py                   # main entry
+├── nn_model_build.py         # the demo model
+├── nn_model_input.py         # defines model_fn
+├── nn_optim.py               # defines the training ops
+├── nn_reader.py              # defines input_fn
+├── op_precision.ini          # operator execution mode configuration
+├── random_data_generator.py  # data generator
+├── README.md                 # instructions for running the demo
+├── run.sh                    # demo launch script
+├── tf_adapter.py             # imports the TF adapter
+└── utils.py                  # common helper functions
+```
+
+## 1. Prepare the data
+The demo needs no external dataset; mxRec generates the data automatically. See dataset.py and random_data_generator.py for details.
+
+## 2. Prepare the runtime environment
+Prepare the runtime environment following the "安装部署" chapter of the [mxRec用户指南](https://www.hiascend.com/document/detail/zh/mind-sdk/60rc1/mxRec/mxrecug/mxrecug_0007.html).
+
+## 3. Install mxRec
+Download the mxRec package via the links in the "安装部署" > "环境准备" > "获取软件包" chapter of the [mxRec用户指南](https://www.hiascend.com/document/detail/zh/mind-sdk/60rc1/mxRec/mxrecug/mxrecug_0007.html), choosing the architecture you need (x86 or arm). After downloading, extract the package and enter the extracted directory (mindxsdk-mxrec), which looks like:
+```shell
+.
+├── cust_op
+│   └── cust_op_by_addr
+├── examples
+│   ├── DCNv2
+│   ├── demo
+│   └── dlrm
+├── tf1_whl
+│   └── mx_rec-{version}-py3-none-linux_x86_64.whl  # version is the release number
+├── tf2_whl
+│   └── mx_rec-{version}-py3-none-linux_x86_64.whl  # version is the release number
+└── version.info
+```
+The tf1_whl and tf2_whl directories hold the mxRec wheels adapted to TF1 and TF2 respectively; install whichever one you need with pip/pip3 install. Note the directory mxRec is installed to, for example /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec.
+
+## 4. Run the demo model
+After the steps above, run the demo model with run.sh (8 devices by default). The script takes an ip argument, for example:
+```shell
+bash run.sh main.py {ip}
+```
+* ip: the IP address of the machine running the model.
+
+**Tips**: run.sh has a mx_rec_package_path parameter, the mxRec installation directory, e.g. /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec. The script ships a default value; set it to the actual installation path in your environment.
\ No newline at end of file
diff --git a/examples/demo/little_demo_estimator/main.py b/examples/demo/little_demo_estimator/main.py
index 901bf23a3c545fc9f1bdda339c62f76dd4ceef7f..716e40d05f69e9b28722f2724d6d4d1b15f49026 100644
--- a/examples/demo/little_demo_estimator/main.py
+++ b/examples/demo/little_demo_estimator/main.py
@@ -17,18 +17,18 @@

 import argparse
 import os
+import shutil
+from glob import glob

 import tensorflow as tf
-
 from mx_rec.util.initialize import init, terminate_config_initializer
-from mx_rec.util.communication.hccl_ops import get_rank_id
 from mx_rec.core.asc.helper import FeatureSpec
 from mx_rec.graph.modifier import GraphModifierHook
-from mx_rec.graph.acg_push_ops import ACGPushOpsToDatasetHook
+from mx_rec.graph.hooks import OrphanLookupKeySlicerHook, LookupSubgraphSlicerHook
 from mx_rec.core.feature_process import EvictHook
 from mx_rec.util.log import logger

-from tf_adapter import NPURunConfig, NPUEstimator, npu_hooks_append, DumpConfig
+from tf_adapter import NPURunConfig, NPUEstimator, npu_hooks_append
 from nn_reader import input_fn
 from nn_model_input import get_model_fn
 from config import Config
@@ -37,7 +37,7 @@ from utils import FeatureSpecIns
 tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)


-def main(params, cfg):
+def main(params, config):
     mg_session_config = tf.compat.v1.ConfigProto(allow_soft_placement=True, log_device_placement=False)
     run_config = NPURunConfig(
         model_dir=params.model_dir,
@@ -58,51 +58,58 @@ def main(params, config):
     # access_threshold unit counts; eviction_threshold unit seconds
     access_and_evict = None
-    if not params.enable_push_ops_test:
+    if not params.enable_slicer_test:
         hooks_list = [GraphModifierHook(modify_graph=params.modify_graph)]
     else:
-        hooks_list = [ACGPushOpsToDatasetHook(dump_graph=True), GraphModifierHook(modify_graph=params.modify_graph)]
+        orphan_slicer_hook = OrphanLookupKeySlicerHook()
+        lookup_slicer_hook = LookupSubgraphSlicerHook(op_types=["StringToNumber"])
+        hooks_list = [orphan_slicer_hook, lookup_slicer_hook, GraphModifierHook(modify_graph=params.modify_graph)]
     if params.use_timestamp:
-        config_for_user_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold)
-        config_for_item_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold)
+        config_for_user_table = dict(access_threshold=config.access_threshold,
+                                     eviction_threshold=config.eviction_threshold)
+        config_for_item_table = dict(access_threshold=config.access_threshold,
+                                     eviction_threshold=config.eviction_threshold)
         access_and_evict = dict(user_table=config_for_user_table, item_table=config_for_item_table)
+
         evict_hook = EvictHook(evict_enable=True, evict_time_interval=10)
         hooks_list.append(evict_hook)
-
     create_fs_params = dict(cfg=config, use_timestamp=params.use_timestamp,
                             use_multi_lookup=use_multi_lookup, multi_lookup_times=MULTI_LOOKUP_TIMES)

     est = NPUEstimator(
-        model_fn=get_model_fn(create_fs_params, cfg, access_and_evict),
+        model_fn=get_model_fn(create_fs_params, config, access_and_evict),
         params=params,
         model_dir=params.model_dir,
         config=run_config
     )
     if params.run_mode == 'train':
-        est.train(input_fn=lambda: input_fn(params, create_fs_params, cfg), max_steps=params.max_steps,
+        est.train(input_fn=lambda: input_fn(params, create_fs_params, config), max_steps=params.max_steps,
                   hooks=npu_hooks_append(hooks_list))
     elif params.run_mode == 'train_and_evaluate':
-        train_spec = tf.estimator.TrainSpec(input_fn=lambda: input_fn(params, create_fs_params, cfg,
+        train_spec = tf.estimator.TrainSpec(input_fn=lambda: input_fn(params, create_fs_params, config,
                                                                       use_one_shot=args.use_one_shot),
                                             max_steps=params.max_steps, hooks=npu_hooks_append(hooks_list))
-        # 在开启evict时,eval时不支持淘汰,所以无需加入evict hook
-        if not params.enable_push_ops_test:
+        if not params.enable_slicer_test:
+            # eviction is not supported during eval, so there is no need to add the evict hook
             eval_hook_list = [GraphModifierHook(modify_graph=params.modify_graph)]
         else:
-            eval_hook_list = [ACGPushOpsToDatasetHook(dump_graph=True),
+            orphan_slicer_hook = OrphanLookupKeySlicerHook()
+            lookup_slicer_hook = LookupSubgraphSlicerHook(op_types=["StringToNumber"])
+            eval_hook_list = [orphan_slicer_hook, lookup_slicer_hook,
                               GraphModifierHook(modify_graph=params.modify_graph)]
-        eval_spec = tf.estimator.EvalSpec(input_fn=lambda: input_fn(params, create_fs_params, cfg, is_eval=True,
+        eval_spec = tf.estimator.EvalSpec(input_fn=lambda: input_fn(params, create_fs_params, config, is_eval=True,
                                                                     use_one_shot=args.use_one_shot),
                                           steps=params.eval_steps, hooks=npu_hooks_append(eval_hook_list),
                                           throttle_secs=0)
         tf.estimator.train_and_evaluate(est, train_spec=train_spec, eval_spec=eval_spec)
     elif params.run_mode == 'predict':
-        results = est.predict(input_fn=lambda: input_fn(params, create_fs_params, cfg),
+        results = est.predict(input_fn=lambda: input_fn(params, create_fs_params, config),
                               hooks=npu_hooks_append(hooks_list=hooks_list), yield_single_examples=False)
         output_pred1 = []
         output_pred2 = []
@@ -138,6 +145,27 @@ def create_feature_spec_list(use_timestamp=False):
     return feature_spec_list


+def _del_related_dir(del_path: str) -> None:
+    if not os.path.isabs(del_path):
+        del_path = os.path.join(os.getcwd(), del_path)
+    dirs = glob(del_path)
+    for sub_dir in dirs:
+        shutil.rmtree(sub_dir, ignore_errors=True)
+        logger.info(f"delete dir:{sub_dir}")
+
+
+def _clear_saved_model() -> None:
+    _del_related_dir("/root/ascend/log/*")
+    _del_related_dir("kernel*")
+    _del_related_dir("export_graph")
+
+    mode = args.run_mode
+    if not mode.startswith("train"):
+        return
+    logger.info("current mode contains train, will delete previous saved model data if exist.")
+    _del_related_dir("_rank*")
+
+
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument('--run_mode', type=str, default='train_and_evaluate')  # run mode, configured in run.sh
@@ -158,22 +186,21 @@ if __name__ == '__main__':
     # get init configuration
     try:
         use_dynamic = bool(int(os.getenv("USE_DYNAMIC", 0)))
-        use_hot = bool(int(os.getenv("USE_HOT", 0)))
         use_dynamic_expansion = bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0)))
         use_multi_lookup = bool(int(os.getenv("USE_MULTI_LOOKUP", 1)))
         MODIFY_GRAPH_FLAG = bool(int(os.getenv("USE_MODIFY_GRAPH", 0)))
         USE_TIMESTAMP = bool(int(os.getenv("USE_TIMESTAMP", 0)))
         args.use_one_shot = bool(int(os.getenv("USE_ONE_SHOT", 0)))
-        args.enable_push_ops_test = bool(int(os.getenv("ENABLE_PUSH_OPS_TEST", 0)))
+        args.enable_slicer_test = bool(int(os.getenv("ENABLE_SLICER_TEST", 0)))
     except ValueError as err:
-        raise ValueError(f"please correctly config USE_MPI or USE_DYNAMIC or USE_HOT or USE_DYNAMIC_EXPANSION or "
-                         f"USE_MULTI_LOOKUP or USE_MODIFY_GRAPH or USE_TIMESTAMP or USE_ONE_SHOT "
-                         f"only 0 or 1 is supported.") from err
+        raise ValueError("please correctly config USE_MPI or USE_DYNAMIC or USE_DYNAMIC_EXPANSION or "
+                         "USE_MULTI_LOOKUP or USE_MODIFY_GRAPH or USE_TIMESTAMP or USE_ONE_SHOT; "
+                         "only 0 or 1 is supported.") from err
     try:
         MULTI_LOOKUP_TIMES = int(os.getenv("MULTI_LOOKUP_TIMES", 2))
     except ValueError as err:
-        raise ValueError(f"please correctly config MULTI_LOOKUP_TIMES only int is supported.") from err
+        raise ValueError("please correctly config MULTI_LOOKUP_TIMES; only int is supported.") from err

     if args.run_mode == 'train':
         args.train_steps = -1
@@ -182,12 +209,14 @@ if __name__ == '__main__':
         args.eval_steps = -1
     elif args.run_mode == 'train_and_evaluate':
         args.save_checkpoints_steps = args.train_steps
+    _clear_saved_model()

     # set init
     init(train_steps=args.train_steps,
          eval_steps=args.eval_steps,
+         save_steps=args.save_checkpoints_steps,
+         max_steps=args.max_steps,
          use_dynamic=use_dynamic,
-         use_hot=use_hot,
          use_dynamic_expansion=use_dynamic_expansion)

     args.model_dir = f"{args.model_ckpt_dir}_rank"
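Both demos repeat the same `bool(int(os.getenv(...)))` parsing for every 0/1 switch, with a shared try/except around the whole block. A hypothetical helper that factors this out and reports the offending variable by name (illustrative only, not part of the mxRec API):

```python
import os

# Sketch: parse a 0/1 environment switch with per-variable validation,
# replacing the repeated bool(int(os.getenv(...))) blocks in the demos.
def get_bool_env(name: str, default: int = 0) -> bool:
    raw = os.getenv(name, str(default))
    try:
        value = int(raw)
    except ValueError as err:
        raise ValueError(f"please correctly config {name}: only 0 or 1 is supported, got {raw!r}") from err
    if value not in (0, 1):
        raise ValueError(f"please correctly config {name}: only 0 or 1 is supported, got {raw!r}")
    return bool(value)


use_dynamic = get_bool_env("USE_DYNAMIC")
modify_graph_flag = get_bool_env("USE_MODIFY_GRAPH")
```

Validating each variable separately also makes the failure message name the exact misconfigured switch, instead of listing every candidate as the current messages do.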
diff --git a/examples/demo/little_demo_estimator/nn_model_build.py b/examples/demo/little_demo_estimator/nn_model_build.py
index e715f930e5f30f0e9b894b0274ecaf67108df901..aeeab8f82b2672bc95f425ff6519e24cdf8f69a1 100644
--- a/examples/demo/little_demo_estimator/nn_model_build.py
+++ b/examples/demo/little_demo_estimator/nn_model_build.py
@@ -21,7 +21,6 @@
 from mx_rec.util.tf_version_adapter import npu_ops
 from mx_rec.core.embedding import create_table, sparse_lookup
 from mx_rec.constants.constants import ASCEND_TIMESTAMP
-from nn_optim import get_dense_and_sparse_optimizer
 from utils import FeatureSpecIns
@@ -137,25 +136,21 @@ class LittleModel:
         return logit_list

     def _get_embedding_list(self):
-        optimizer_list = [get_dense_and_sparse_optimizer(self.cfg)]
-        sparse_optimizer_list = [sparse_optimizer for dense_optimizer, sparse_optimizer in optimizer_list]
         user_hashtable = create_table(key_dtype=tf.int64,
                                       dim=tf.TensorShape([self.cfg.user_hashtable_dim]),
                                       name='user_table',
                                       emb_initializer=tf.compat.v1.truncated_normal_initializer(),
                                       device_vocabulary_size=self.cfg.user_vocab_size * 10,
-                                      host_vocabulary_size=self.cfg.user_vocab_size * 0,
-                                      optimizer_list=sparse_optimizer_list)
+                                      host_vocabulary_size=self.cfg.user_vocab_size * 0)
         item_hashtable = create_table(key_dtype=tf.int64,
                                       dim=tf.TensorShape([self.cfg.item_hashtable_dim]),
                                       name='item_table',
                                       emb_initializer=tf.compat.v1.truncated_normal_initializer(),
                                       device_vocabulary_size=self.cfg.item_vocab_size * 10,
-                                      host_vocabulary_size=self.cfg.item_vocab_size * 0,
-                                      optimizer_list=sparse_optimizer_list)
+                                      host_vocabulary_size=self.cfg.item_vocab_size * 0)

         if self.params.modify_graph:
-            if not self.params.enable_push_ops_test:
+            if not self.params.enable_slicer_test:
                 input_list = [[self.features["user_ids"], self.features["item_ids"]],
                               [user_hashtable, item_hashtable],
                               [self.cfg.user_send_cnt, self.cfg.item_send_cnt],
@@ -207,15 +202,16 @@ class LittleModel:
         return embedding_list


-def _make_ids_with_const_ops(input: Tensor) -> Tensor:
-    const_ids = tf.constant(1, shape=input.shape, dtype=input.dtype)
+def _make_ids_with_const_ops(input_tensor: Tensor) -> Tensor:
+    const_ids = tf.constant(1, shape=input_tensor.shape, dtype=input_tensor.dtype)
     const_ids = tf.compat.v1.add(const_ids, 1)
     const_ids = tf.compat.v1.subtract(const_ids, 1)

     return const_ids


-def _make_ids_with_str_ops(input: Tensor) -> Tensor:
-    str_ids = tf.compat.v1.strings.as_string(input)
+
+def _make_ids_with_str_ops(input_tensor: Tensor) -> Tensor:
+    str_ids = tf.compat.v1.strings.as_string(input_tensor)
     str_ids = tf.compat.v1.strings.to_number(str_ids)
-
+
     return str_ids
diff --git a/examples/demo/little_demo_estimator/nn_model_input.py b/examples/demo/little_demo_estimator/nn_model_input.py
index 2ce70d412b5b3c1af1375cb5c26093c9260c2bf3..973a457cb75ce58673e4e76d5e39604e846aa6b8 100644
--- a/examples/demo/little_demo_estimator/nn_model_input.py
+++ b/examples/demo/little_demo_estimator/nn_model_input.py
@@ -17,10 +17,10 @@
 import tensorflow as tf

 from mx_rec.constants.constants import ASCEND_TIMESTAMP
+from mx_rec.util.log import logger

 from nn_model_build import LittleModel
 from nn_optim import get_train_op
-from mx_rec.util.log import logger


 def get_model_fn(create_fs_params, cfg, access_and_evict_config_dict=None):
@@ -29,7 +29,7 @@ def get_model_fn(create_fs_params, cfg, access_and_evict_config_dict=None):
         if params.use_timestamp:
             model = LittleModel(params, cfg, mode, features, create_fs_params,
                                 access_and_evict_config_dict=access_and_evict_config_dict)
-            tf.add_to_collection(ASCEND_TIMESTAMP, features["timestamp"])
+            tf.compat.v1.add_to_collection(ASCEND_TIMESTAMP, features["timestamp"])
         else:
             model = LittleModel(params, cfg, mode, features, create_fs_params)
     else:
@@ -39,19 +39,19 @@
         loss_dict = {}

         if mode == tf.estimator.ModeKeys.TRAIN:
-            logger.info(f"use estimator train mode")
+            logger.info("Use estimator train mode")
             loss_dict['loss'] = [['train_loss', loss]]
             return tf.estimator.EstimatorSpec(mode=mode,
                                               loss=loss,
                                               train_op=get_train_op(params, loss_dict.get('loss')))

         if mode == tf.estimator.ModeKeys.EVAL:
-            logger.info("use estimator eval mode")
+            logger.info("Use estimator eval mode")
             return tf.estimator.EstimatorSpec(mode=mode, loss=loss)

         if mode == tf.estimator.ModeKeys.PREDICT:
-            logger.info("use estimator predict mode")
+            logger.info("Use estimator predict mode")
             loss_dict['task_1'] = prediction[0]
             loss_dict['task_2'] = prediction[1]
diff --git a/examples/demo/little_demo_estimator/nn_optim.py b/examples/demo/little_demo_estimator/nn_optim.py
index 4438627da6db9bb5b664425b9059cb9a8b3c5dcc..d07556a60e7aea4d449ccff83bd3033f46b0377a 100644
--- a/examples/demo/little_demo_estimator/nn_optim.py
+++ b/examples/demo/little_demo_estimator/nn_optim.py
@@ -28,18 +28,6 @@ from mx_rec.optimizers.gradient_descent_by_addr import create_hash_optimizer_by_addr
 from mx_rec.util.log import logger


-def get_dense_and_sparse_optimizer(cfg):
-    dense_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=cfg.learning_rate)
-    if ConfigInitializer.get_instance().use_dynamic_expansion:
-        sparse_optimizer = create_hash_optimizer_by_addr(learning_rate=cfg.learning_rate)
-        logger.info("optimizer create_hash_optimizer_by_addr")
-    else:
-        sparse_optimizer = create_hash_optimizer(learning_rate=cfg.learning_rate)
-        logger.info("optimizer create_hash_optimizer")
-
-    return dense_optimizer, sparse_optimizer
-
-
 def get_train_op_list(losses, learning_rate):
     train_ops_list = []
     update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
@@ -55,9 +43,7 @@ def get_train_op_list(losses, learning_rate):
     dense_variables, sparse_variables = get_dense_and_sparse_variable()
     trainable_variables = [dense_variables]

-    for i in range(len(losses)):
-        name = losses[i][0]
-        loss = losses[i][1]
+    for _, (name, loss) in enumerate(losses):
         with tf.control_dependencies(update_ops):
             # do dense grad
             grads = dense_optimizer.compute_gradients(loss, var_list=trainable_variables)
@@ -73,11 +59,11 @@ def get_train_op_list(losses, learning_rate):

         # do sparse optimization
         if use_dynamic_expansion:
-            from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS
+            from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET

             train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB)
-            train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS)
+            train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET)

             local_grads = tf.gradients(loss, train_emb_list)  # local_embedding
             grads_and_vars = [(grad, address) for grad, address in zip(local_grads, train_address_list)]
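The dense branch of get_train_op_list above follows the standard data-parallel recipe: sum-allreduce each gradient across devices, divide by the rank count, then apply the mean. A minimal sketch of that pattern, assuming the `hccl_ops` module used elsewhere in this diff (the helper name and import path are illustrative, and dense tensors are assumed, not IndexedSlices):

```python
import tensorflow as tf
from mx_rec.util.communication import hccl_ops

# Sketch of the dense-gradient averaging in get_train_op_list: every rank
# computes local gradients, the "sum" allreduce combines them, and dividing
# by rank_size yields the same mean gradient on every worker.
def build_dense_train_op(dense_optimizer, loss, dense_variables, rank_size):
    grads_and_vars = dense_optimizer.compute_gradients(loss, var_list=dense_variables)
    avg_grads = []
    for grad, var in grads_and_vars:
        if grad is None:
            continue
        if rank_size > 1:
            grad = hccl_ops.allreduce(grad, "sum") / rank_size
        avg_grads.append((grad, var))
    return dense_optimizer.apply_gradients(avg_grads)
```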
diff --git a/examples/demo/little_demo_estimator/run.sh b/examples/demo/little_demo_estimator/run.sh
index 33770e59b5c0029b44a2189686eacc6f46cc16c9..011f0001d7af2a9006c92332b4c75366de9eb79e 100644
--- a/examples/demo/little_demo_estimator/run.sh
+++ b/examples/demo/little_demo_estimator/run.sh
@@ -15,9 +15,6 @@
 # ==============================================================================

 kill -9 `ps -ef | grep python | grep -v grep | awk '{print $2}'` > /dev/null 2>&1
-rm -rf /root/ascend/log/*
-rm -rf ./kernel*
-rm -rf ./export_graph/*

 # 获取输入参数:py、ip
 if [ $# -ge 1 ]; then
@@ -83,17 +80,10 @@
 export TF_CPP_MIN_LOG_LEVEL=3  # tensorflow日志级别,3对应FATAL
 # 设置应用类日志的全局日志级别及各模块日志级别,具体请参考昇腾官网CANN文档
 export ASCEND_GLOBAL_LOG_LEVEL=3  # “设置日志级别”章节0:debug, 1:info, 2:warning, 3:error, 4:NULL
 export MXREC_MODE="ASC"
-export USE_MPI=1
-export USE_MODE="train_and_evaluate" # 支持[train, predict, train_and_evaluate]
-
-if [ $USE_MODE = "train" ] || [ $USE_MODE = "train_and_evaluate" ];then
-    echo "train mode: saved-model will be deleted"
-    rm -rf ./_rank*
-fi
+export USE_MODE="train_and_evaluate"  # supported: [train, predict, train_and_evaluate]; train-related modes delete the ./_rank* directories

 ################# 参数配置 ######################
 export USE_DYNAMIC=1  # 0:静态shape;1:动态shape
-export USE_HOT=1  # 0:关闭hot emb;1: 开启hot emb
 export USE_DYNAMIC_EXPANSION=0  # 0:关闭动态扩容;1: 开启动态扩容
 export USE_MULTI_LOOKUP=1  # 0:一表一查;1:一表多查
 export MULTI_LOOKUP_TIMES=2  # 一表多查次数:默认2,上限127(因为一表已经有一查);仅当export USE_MULTI_LOOKUP=1时生效
@@ -106,8 +96,7 @@
 export KEY_PROCESS_THREAD_NUM=6  #default 6, max 10
 export FAST_UNIQUE=0  #if use fast unique
 export MGMT_HBM_TASK_MODE=0  #if async h2d (get and send tensors)
 ################## 测试配置项 #####################
-# NOTE: 仅在测试constant、string相关op作为稀疏表输入时启用,当前版本只支持TF1。
-export ENABLE_PUSH_OPS_TEST=0
+export ENABLE_SLICER_TEST=0

 # 帮助信息,不需要修改
 if [[ $1 == --help || $1 == -h ]];then
@@ -152,7 +141,6 @@ else
     echo "CM_CHIEF_DEVICE=$CM_CHIEF_DEVICE"
     echo "CM_WORKER_IP=$CM_WORKER_IP"
     echo "CM_WORKER_SIZE=$CM_WORKER_SIZE"
-    echo "ASCEND_VISIBLE_DEVICES=$ASCEND_VISIBLE_DEVICES"
     #########################################################
 else
     echo "ip: $ip not available!"  # 使用ranktable方案
@@ -169,4 +157,4 @@
 DATE=$(date +%Y-%m-%d-%H-%M-%S)
 horovodrun --network-interface ${interface} -np ${num_process} --mpi-args "${mpi_args}" --mpi -H localhost:${local_rank_size} \
 python3.7 ${py} \
 --run_mode=$USE_MODE \
-2>&1 | tee "temp_${local_rank_size}p_${KEY_PROCESS_THREAD_NUM}t_${DATE}.log"
+2>&1 | tee "temp_${num_process}p_${KEY_PROCESS_THREAD_NUM}t_${DATE}.log"
diff --git a/examples/dlrm/README.md b/examples/dlrm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..85293c0c38985f3b4c956650ddb99e1ae0749994
--- /dev/null
+++ b/examples/dlrm/README.md
@@ -0,0 +1,60 @@
+# Running the DLRM model
+
+## Code structure
+```shell
+.
+├── criteo_tb
+│   ├── gen_ttf.py             # converts the raw criteo_tb data to tfrecord format
+│   └── README.md              # data conversion instructions
+├── model
+│   ├── config.py              # model configuration
+│   ├── delay_loss_scale.py    # loss scaling helpers
+│   ├── gradient_descent_w.py  # customized SGD optimizer
+│   ├── main_mxrec.py          # main entry
+│   ├── mean_auc.py            # AUC computation script
+│   ├── model.py               # the DLRM model
+│   ├── op_impl_mode.ini       # operator execution mode configuration
+│   ├── optimizer.py           # optimizers
+│   └── run.sh                 # DLRM launch script
+└── README.md                  # instructions for running DLRM
+```
+
+## 1. Prepare the data
+Prepare the dataset following the README in the criteo_tb directory and place it in a single directory, for example /data/criteo_tb/.
+
+## 2. Prepare the runtime environment
+Prepare the runtime environment following the "安装部署" chapter of the [mxRec用户指南](https://www.hiascend.com/document/detail/zh/mind-sdk/60rc1/mxRec/mxrecug/mxrecug_0007.html).
+
+## 3. Install mxRec
+Download the mxRec package via the links in the "安装部署" > "环境准备" > "获取软件包" chapter of the [mxRec用户指南](https://www.hiascend.com/document/detail/zh/mind-sdk/60rc1/mxRec/mxrecug/mxrecug_0007.html), choosing the architecture you need (x86 or arm). After downloading, extract the package and enter the extracted directory (mindxsdk-mxrec), which looks like:
+```shell
+.
+├── cust_op
+│   └── cust_op_by_addr
+├── examples
+│   ├── DCNv2
+│   ├── demo
+│   └── dlrm
+├── tf1_whl
+│   └── mx_rec-{version}-py3-none-linux_x86_64.whl  # version is the release number
+├── tf2_whl
+│   └── mx_rec-{version}-py3-none-linux_x86_64.whl  # version is the release number
+└── version.info
+```
+The tf1_whl and tf2_whl directories hold the mxRec wheels adapted to TF1 and TF2 respectively; install whichever one you need with pip/pip3 install. Note the directory mxRec is installed to, for example /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec.
+
+## 4. Run the DLRM model
+After the steps above, run the DLRM model with run.sh (8 devices by default). The script takes five arguments: so_path, mx_rec_package_path, hccl_cfg_json, dlrm_criteo_data_path and ip. For example:
+```shell
+bash run.sh {so_path} {mx_rec_package_path} {hccl_cfg_json} {dlrm_criteo_data_path} {ip}
+```
+* so_path: the directory of the mxRec shared libraries, normally the libasc directory under the mxRec installation, e.g. /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec/libasc.
+* mx_rec_package_path: the mxRec installation directory, e.g. /usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec.
+* hccl_cfg_json: the HCCL communication configuration file; if the ip argument is provided it is unused and an empty string "" can be passed.
+* dlrm_criteo_data_path: the dataset directory, e.g. /data/criteo_tb/.
+* ip: the IP address of the machine running the model; configuring it is recommended.
diff --git a/examples/dlrm/criteo_tb/gen_ttf.py b/examples/dlrm/criteo_tb/gen_ttf.py
index 92fabb3d2deb09cd50fb09083bb05feeaf31fc23..986bc6df4019846c9c440a8edce8c5b76529af6c 100644
--- a/examples/dlrm/criteo_tb/gen_ttf.py
+++ b/examples/dlrm/criteo_tb/gen_ttf.py
@@ -19,12 +19,12 @@
 import collections
 import logging
 import argparse
 from multiprocessing import Process
-import numpy as np
+import sys
 import time
+import numpy as np
 from tqdm import tqdm
 from glob import glob
 from collections import Counter, OrderedDict
-import sys

 import tensorflow as tf
@@ -50,11 +50,11 @@ class Logger(object):
         self.logger.addHandler(sh)  # 把对象加到logger里
         self.logger.addHandler(th)

-    def info(self, *args):
-        if len(args) == 1:
-            self.logger.info(*args)
+    def info(self, *params):
+        if len(params) == 1:
+            self.logger.info(*params)
         else:
-            self.logger.info([*args])
+            self.logger.info([*params])


 class CriteoStatsDict():
@@ -89,12 +89,11 @@ class CriteoStatsDict():
         for i, cat in enumerate(cat_list):
             map_cat_count(i, cat)

-    #
-    def save_dict(self, output_path, hist_map, prefix=""):
-        with open(os.path.join(output_path, "{}hist_map.pkl".format(prefix)), "wb") as file_wrt:
+    @staticmethod
+    def save_dict(output_file_path, hist_map, prefix=""):
+        dict_file = os.path.join(output_file_path, "{}hist_map.pkl".format(prefix))
+        with os.fdopen(os.open(dict_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o640), "wb") as file_wrt:
             pickle.dump(hist_map, file_wrt)

-    #
     def load_dict(self, dict_path, prefix=""):
         with open(os.path.join(dict_path, "{}hist_map.pkl".format(prefix)), "rb") as file_wrt:
             self.hist_map = pickle.load(file_wrt)
@@ -128,13 +127,14 @@
         return dense_list, cat_list


-def statsdata_multiprocess(process_num, process_id, data_file_path, output_path, criteo_stats):
+
+def statsdata_multiprocess(proc_num, proc_id, data_file_path, output_file_path, criteo_stats_data):
     start_time = time.time()
     with open(data_file_path, encoding="utf-8") as file_in:
         errorline_list = []
         count = 0
         for i, line in enumerate(file_in):
-            if i % process_num != process_id:
+            if i % proc_num != proc_id:
                 continue
             count += 1
             line = line.strip("\n")
@@ -146,26 +146,26 @@
             if count % 1000000 == 0:
                 print("Have handle {}w lines.".format(count // 10000))
             cats = items[14:]
-            criteo_stats.stats_cats(cats)
-    criteo_stats.save_dict(output_path)
+            criteo_stats_data.stats_cats(cats)
+    criteo_stats_data.save_dict(output_file_path)
     print('statsdata time cost: {:.2f}s'.format(time.time() - start_time))


-def get_unique_id_multiprocess(process_num, process_id, data_file_path, output_path, criteo_stats):
-    if os.path.exists(os.path.join(output_path, "unique_id.pkl")):
+def get_unique_id_multiprocess(proc_num, proc_id, data_file_path, output_file_path, criteo_stats_data):
+    if os.path.exists(os.path.join(output_file_path, "unique_id.pkl")):
         return
     start_time = time.time()
-    cat_sets = [OrderedDict() for col in criteo_stats.cat_cols]
-    cat_global_id_nums = [0 for col in criteo_stats.cat_cols]
-    hash_bucket = criteo_stats.hash_bucket
+    cat_sets = [OrderedDict() for col in criteo_stats_data.cat_cols]
+    cat_global_id_nums = [0 for col in criteo_stats_data.cat_cols]
+    hash_bucket = criteo_stats_data.hash_bucket
     line_num = 0
     with open(data_file_path, encoding="utf-8") as file_in:
         errorline_list = []
         for i, line in enumerate(file_in):
             line_num += 1
-    start_line = process_id * ((line_num + process_num) // process_num)
-    end_line = (process_id + 1) * ((line_num + process_num) // process_num)
+    start_line = proc_id * ((line_num + proc_num) // proc_num)
+    end_line = (proc_id + 1) * ((line_num + proc_num) // proc_num)
     with open(data_file_path, encoding="utf-8") as file_in:
         errorline_list = []
         count = 0
@@ -183,21 +183,17 @@
                 print("Have handle {}w lines.".format(count // 10000))
                 sys.stdout.flush()
             cats = items[14:]
-            # criteo_stats.stats_cats(cats)
-            # def map_cat_count(i, cat):
             for k, cat in enumerate(cats):
-                # map_cat_count(i, cat)
                 capped_value = int(cat, 16) % hash_bucket if cat else hash_bucket
-                # if capped_value not in self.hist_map[key_col]:
                 if capped_value not in cat_sets:
                     cat_sets[k][capped_value] = cat_global_id_nums[k]
                     cat_global_id_nums[k] += 1
-    with open(os.path.join(output_path, "unique_id.pkl"), "wb") as file_wrt:
+    unique_id_file = os.path.join(output_file_path, "unique_id.pkl")
+    with os.fdopen(os.open(unique_id_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o640), "wb") as file_wrt:
         pickle.dump(cat_sets, file_wrt)
     print('statsdata time cost: {:.2f}s'.format(time.time() - start_time))


-def merge_stats_count(stats_dir, criteo_stats):
+def merge_stats_count(stats_dir, criteo_stats_data):
     if os.path.exists(f'{stats_dir}/hist_map.pkl'):
         return
     stats_sub_dirs = sorted(glob(f'{stats_dir}/*[0-9]'))
@@ -207,15 +203,15 @@
     for i in tqdm(range(1, len(stats_sub_dirs))):
         with open(f'{stats_sub_dirs[i]}/unique_id.pkl', 'rb') as f:
             others_count = pickle.load(f)
-        for k, _ in enumerate(criteo_stats.cat_cols):
+        for k, _ in enumerate(criteo_stats_data.cat_cols):
             all_count_1, others_count_1 = all_hist_map[k], others_count[k]
             all_count_1.update(others_count_1)
             all_hist_map[k] = all_count_1
     hist_map = {}
-    for i, col in enumerate(criteo_stats.cat_cols):
+    for i, col in enumerate(criteo_stats_data.cat_cols):
         hist_map[col] = dict(zip(list(all_hist_map[i].keys()), range(len(all_hist_map[i]))))
-    criteo_stats.save_dict(stats_dir, hist_map)
+    criteo_stats_data.save_dict(stats_dir, hist_map)


 def mkdir_path(file_path):
@@ -228,20 +224,21 @@
 def make_example(label_list, dense_feat_list, sparse_feat_list):
     sparse_feature = np.array(sparse_feat_list, dtype=np.int64).reshape(-1)
     label = np.array(label_list, dtype=np.int64).reshape(-1)
     feature_dict = {"dense_feature": tf.train.Feature(float_list=tf.train.FloatList(value=dense_feature)),
-        "sparse_feature": tf.train.Feature(int64_list=tf.train.Int64List(value=sparse_feature)),
-        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=label))
-    }
+                    "sparse_feature": tf.train.Feature(int64_list=tf.train.Int64List(value=sparse_feature)),
+                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=label))
+                    }
     example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
     return example


-def convert_input2tfrd_multiprocess(process_num, process_id, in_file_path, output_path, criteo_stats, line_per_sample=1024,
-                                    part_rows=2000000, mode="train_"):
+
+def convert_input2tfrd_multiprocess(proc_num, proc_id, in_file_path, output_file_path, criteo_stats_dict,
+                                    line_per_sample=1024, part_rows=2000000):
     start_time = time.time()
     print("----------" * 10 + "\n" * 2)
     part_number = 0
-    file_name = output_path + "part_{:0>8d}.tfrecord"
+    file_name = output_file_path + "part_{:0>8d}.tfrecord"
     file_writer = tf.python_io.TFRecordWriter(file_name.format(part_number))

     sample_count = 0
@@ -250,11 +247,11 @@
     with open(in_file_path, encoding="utf-8") as file_in:
         errorline_list = []
-        for i, line in tqdm(enumerate(file_in)):
+        for _ in tqdm(file_in):
             line_num += 1
     print(f'line_num: {line_num}')
-    start_line = process_id * ((line_num + process_num) // process_num)
-    end_line = (process_id + 1) * ((line_num + process_num) // process_num)
+    start_line = proc_id * ((line_num + proc_num) // proc_num)
+    end_line = (proc_id + 1) * ((line_num + proc_num) // proc_num)
     dense_res_list = []
     cat_res_list = []
     label_res_list = []
@@ -276,9 +273,11 @@
             label = int(items[0])
             values = items[1:14]
             cats = items[14:]
-            assert len(values) == 13, "values.size: {}".format(len(values))
-            assert len(cats) == 26, "cats.size: {}".format(len(cats))
-            val_list, cat_list = criteo_stats.map_cat2id(values, cats)
+            if len(values) != 13:
+                raise ValueError("dense feature length must be 13, current values.size: {}".format(len(values)))
+            if len(cats) != 26:
+                raise ValueError("sparse feature length must be 26, current cats.size: {}".format(len(cats)))
+            val_list, cat_list = criteo_stats_dict.map_cat2id(values, cats)
             dense_res_list.append(val_list)
             cat_res_list.append(cat_list)
             label_res_list.append(label)
@@ -362,16 +361,18 @@
     mkdir_path(save_tfrecord_path)
     processs = []
     process_num = args.train_process_num
-    assert process_num % len(train_data_files) == 0, print(
-        f'process_num {process_num} must exact div length of data_files {len(data_files)}')
+    if len(train_data_files) == 0:
+        raise ValueError(f'file not exist in train_data_dir:{train_data_dir}')
+    if process_num % len(train_data_files) != 0:
+        raise ValueError(f'process_num {process_num} must be divisible by the number of '
+                         f'train_data_files {len(train_data_files)}')
     for process_id in range(process_num):
         sub_process_num = process_num // len(train_data_files)
         data_file = train_data_files[process_id // sub_process_num]
         output_path = f'{save_tfrecord_path}/{process_id:04}_'
-        p = Process(target=convert_input2tfrd_multiprocess, args=(sub_process_num, process_id%sub_process_num, data_file, output_path,
-                                                                  criteo_stats, spe_num,
-                                                                  5000000))
+        p = Process(target=convert_input2tfrd_multiprocess, args=(sub_process_num, process_id % sub_process_num,
+                                                                  data_file, output_path, criteo_stats, spe_num,
+                                                                  5000000))
         processs.append(p)
     for p in processs:
         p.start()
@@ -384,17 +385,18 @@
     mkdir_path(save_tfrecord_path)
     processs = []
     process_num = args.test_process_num
-    assert process_num % len(test_data_files) == 0, print(
-        f'process_num {process_num} must exact div length of data_files {len(data_files)}')
+    if len(test_data_files) == 0:
+        raise ValueError(f'file not exist in test_data_dir:{test_data_dir}')
+    if process_num % len(test_data_files) != 0:
+        raise ValueError(f'process_num {process_num} must be divisible by the number of '
+                         f'test_data_files {len(test_data_files)}')
     for process_id in range(process_num):
         sub_process_num = process_num // len(test_data_files)
         data_file = test_data_files[process_id // sub_process_num]
         output_path = f'{save_tfrecord_path}/{process_id:04}_'
-        p = Process(target=convert_input2tfrd_multiprocess, args=(sub_process_num, process_id%sub_process_num, data_file, output_path,
-                                                                  criteo_stats, spe_num,
-                                                                  5000000))
-
+        p = Process(target=convert_input2tfrd_multiprocess, args=(sub_process_num, process_id % sub_process_num,
+                                                                  data_file, output_path, criteo_stats, spe_num,
+                                                                  5000000))
         processs.append(p)
     for p in processs:
         p.start()
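gen_ttf.py fans the raw text out to multiple processes by giving each worker every proc_num-th line: a worker skips line i unless `i % proc_num == proc_id`. A self-contained sketch of that sharding scheme using only the standard library (the worker function and the "train.txt" path are hypothetical):

```python
from multiprocessing import Process

# Sketch of the modulo line-sharding in statsdata_multiprocess: worker proc_id
# handles exactly the lines with index i where i % proc_num == proc_id, so the
# workers partition the file without coordinating with each other.
def count_shard_lines(proc_num: int, proc_id: int, data_file_path: str) -> None:
    count = 0
    with open(data_file_path, encoding="utf-8") as file_in:
        for i, _ in enumerate(file_in):
            if i % proc_num != proc_id:
                continue
            count += 1
    print(f"process {proc_id}: handled {count} lines")


if __name__ == "__main__":
    procs = [Process(target=count_shard_lines, args=(4, pid, "train.txt")) for pid in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```

convert_input2tfrd_multiprocess uses the contiguous-range variant of the same idea (start_line/end_line per process); both partition the file completely and without overlap.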
diff --git a/examples/dlrm/model/config.py b/examples/dlrm/model/config.py
index 452b2a7f5e6c1a877f37eb8cd395f7798227856e..c30a22d40d96443623cc95cf020d891bfab0c9ca 100644
--- a/examples/dlrm/model/config.py
+++ b/examples/dlrm/model/config.py
@@ -14,12 +14,17 @@
 # limitations under the License.
 # ==============================================================================

+import enum
 import os

 import tensorflow as tf
 from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig
 from npu_bridge.estimator.npu.npu_config import NPURunConfig

+from mx_rec.constants.constants import CacheModeEnum
+
+SSD_DATA_PATH = ["ssd_data"]
+

 class LearningRateScheduler:
     """
@@ -40,7 +45,6 @@ class LearningRateScheduler:
         # used for the warmup stage
         warmup_step = tf.cast(1 / self.warmup_steps, tf.float32)
         lr_factor_warmup = 1 - tf.cast(self.warmup_steps - global_step, tf.float32) * warmup_step
-        # lr_factor_warmup = tf.cast(global_step, tf.float32) / tf.cast(self.warmup_steps, tf.float32) #hx
         lr_factor_warmup = tf.cast(lr_factor_warmup, tf.float32)
         # used for the constant stage
         lr_factor_constant = tf.cast(1.0, tf.float32)
@@ -55,7 +59,6 @@ class LearningRateScheduler:
             global_step < self.decay_end_step,
             lambda: lr_factor_decay,
             lambda: sparse_after_decay,
-            # lambda: 0.000 #hx
         )

         lr_factor_decay_dense = tf.cond(
@@ -91,10 +94,10 @@ class LearningRateScheduler:

 class Config:
     def __init__(self, ):
-        self.rank_id = int(os.getenv("RANK_ID")) if os.getenv("RANK_ID") else None
-        tmp = os.getenv("RANK_SIZE")
+        self.rank_id = int(os.getenv("OMPI_COMM_WORLD_RANK")) if os.getenv("OMPI_COMM_WORLD_RANK") else None
+        tmp = os.getenv("TRAIN_RANK_SIZE")
         if tmp is None:
-            raise ValueError("please export RANK_SIZE")
+            raise ValueError("please export TRAIN_RANK_SIZE")
         self.rank_size = int(tmp)

         self.data_path = os.getenv("DLRM_CRITEO_DATA_PATH")
@@ -119,9 +122,10 @@ class Config:
         self.emb_dim = 128
         self.hashtable_threshold = 1
-        # self.learning_rate = 0.01

         self.USE_PIPELINE_TEST = False
+        # False indicates use SGD optimizer, else use LazyAdam. If True, is incompatible with dynamic_expansion
+        self.use_lazy_adam_optimizer = False

         # 动态学习率
         GLOBAL_BATCH_SIZE = 8192 * 8
@@ -145,30 +149,30 @@ class Config:
         if self.cache_mode is None:
             raise ValueError("please export CACHE_MODE environment variable, support:[HBM, DDR, SSD]")
-        if self.cache_mode == "HBM":
+        if self.cache_mode == CacheModeEnum.HBM.value:
             self.dev_vocab_size = 24_000_000 * self.rank_size
             self.host_vocab_size = 0
-        elif self.cache_mode == "DDR":
+        elif self.cache_mode == CacheModeEnum.DDR.value:
             self.dev_vocab_size = 500_000 * self.rank_size
             self.host_vocab_size = 24_000_000 * self.rank_size
-        elif self.cache_mode == "SSD":
+        elif self.cache_mode == CacheModeEnum.SSD.value:
             self.dev_vocab_size = 100_000 * self.rank_size
             self.host_vocab_size = 2_000_000 * self.rank_size
             self.ssd_vocab_size = 24_000_000 * self.rank_size
         else:
             raise ValueError(f"get CACHE_MODE:{self.cache_mode}, expect in [HBM, DDR, SSD]")

-    def get_emb_table_cfg(self) -> dict:
-        if self.cache_mode == "HBM":
+    def get_emb_table_cfg(self):
+        if self.cache_mode == CacheModeEnum.HBM.value:
             return {"device_vocabulary_size": self.dev_vocab_size}
-        elif self.cache_mode == "DDR":
+        elif self.cache_mode == CacheModeEnum.DDR.value:
             return {"device_vocabulary_size": self.dev_vocab_size,
                     "host_vocabulary_size": self.host_vocab_size}
-        elif self.cache_mode == "SSD":
+        elif self.cache_mode == CacheModeEnum.SSD.value:
             return {"device_vocabulary_size": self.dev_vocab_size,
                     "host_vocabulary_size": self.host_vocab_size,
                     "ssd_vocabulary_size": self.ssd_vocab_size,
-                    "ssd_data_path": ["ssd_data"]}
+                    "ssd_data_path": SSD_DATA_PATH}
         else:
             raise RuntimeError(f"get CACHE_MODE:{self.cache_mode}, check Config.__set_emb_table_size implementation")
@@ -182,8 +186,8 @@ def sess_config(dump_data=False, dump_path="./dump_output", dump_steps="0|1|2"):
     custom_op.parameter_map["mix_compile_mode"].b = False
     custom_op.parameter_map["use_off_line"].b = True
     custom_op.parameter_map["min_group_size"].b = 1
+    # optional setting: level0:pairwise;level1:pairwise
     custom_op.parameter_map["HCCL_algorithm"].s = tf.compat.as_bytes("level0:fullmesh;level1:fullmesh")
-    # custom_op.parameter_map["HCCL_algorithm"].s = tf.compat.as_bytes("level0:pairwise;level1:pairwise")
     custom_op.parameter_map["enable_data_pre_proc"].b = True
     custom_op.parameter_map["iterations_per_loop"].i = 10
     custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision")
@@ -228,7 +232,6 @@ def get_npu_run_config():
         iterations_per_loop=1,
         jit_compile=False,
         op_compiler_cache_mode="enable",
-        HCCL_algorithm="level0:fullmesh;level1:fullmesh"
-        # HCCL_algorithm="level0:pairwise;level1:pairwise"
+        HCCL_algorithm="level0:fullmesh;level1:fullmesh"  # optional: level0:pairwise;level1:pairwise
     )
     return run_config
diff --git a/examples/dlrm/model/delay_loss_scale.py b/examples/dlrm/model/delay_loss_scale.py
index 0cb5068858141c7f2b8df05897c88ddbe1060d3a..01bb0d8f0f2a2b6b65a1bdcb29e3f00c137ba85a 100644
--- a/examples/dlrm/model/delay_loss_scale.py
+++ b/examples/dlrm/model/delay_loss_scale.py
@@ -17,32 +17,48 @@
 import tensorflow as tf
 from tensorflow.python.training import optimizer

+from config import Config
+

 class DenseLossScaleOptimizer:
-    def __init__(self, opt, loss_scale):
+    def __init__(self, opt: optimizer.Optimizer, loss_scale: int) -> None:
         if not isinstance(opt, optimizer.Optimizer):
             raise ValueError('"opt" must be an instance of Optimizer, but got: %s' % type(opt))
         self._optimizer = opt
         self._loss_scale = tf.convert_to_tensor(loss_scale, tf.float32)
-        self._optimizer._learning_rate = self._optimizer._learning_rate / self._loss_scale
+        _update_lr_loss_scale(self._optimizer, loss_scale)

     def compute_gradients(self, loss, var_list=None):
-        return self._optimizer.compute_gradients(loss*self._loss_scale, var_list=var_list)
+        return self._optimizer.compute_gradients(loss * self._loss_scale, var_list=var_list)

     def apply_gradients(self, avg_grads):
         return self._optimizer.apply_gradients(avg_grads)


 class SparseLossScaleOptimizer:
-    def __init__(self, opt, loss_scale):
+    def __init__(self, opt: optimizer.Optimizer, loss_scale: int) -> None:
         if not isinstance(opt, optimizer.Optimizer):
             raise ValueError('"opt" must be an instance of Optimizer, but got: %s' % type(opt))
         self._optimizer = opt
         self._loss_scale = tf.convert_to_tensor(loss_scale, tf.float32)
-        self._optimizer._learning_rate = self._optimizer._learning_rate / self._loss_scale
+        _update_lr_loss_scale(self._optimizer, loss_scale)

     def compute_gradients(self, loss, var_list=None):
-        return tf.gradients(loss*self._loss_scale, var_list)
+        return tf.gradients(loss * self._loss_scale, var_list)

     def apply_gradients(self, grads_and_vars):
-        return self._optimizer.apply_gradients(grads_and_vars)
\ No newline at end of file
+        return self._optimizer.apply_gradients(grads_and_vars)
+
+
+def _update_lr_loss_scale(opt, loss_scale):
+    if loss_scale <= 0:
+        raise RuntimeError("the loss_scale must be greater than zero.")
+    loss_scale = tf.convert_to_tensor(loss_scale, tf.float32)
+    if hasattr(opt, "_lr"):
+        # LazyAdam or Adam optimizer
+        opt._lr = opt._lr / loss_scale
+    elif hasattr(opt, "_learning_rate"):
+        # SGD optimizer
+        opt._learning_rate = opt._learning_rate / loss_scale
+    else:
+        raise RuntimeError("`opt` should have a `_learning_rate` or `_lr` named field.")
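delay_loss_scale.py relies on a simple invariant: multiplying the loss by a scale factor S multiplies every gradient by S, so dividing the learning rate by S leaves the final SGD update unchanged while keeping small mixed-precision gradients representable. A tiny self-contained check of that arithmetic (hypothetical values, plain Python instead of the TF optimizer wrappers):

```python
# Sketch of the invariant behind DenseLossScaleOptimizer/_update_lr_loss_scale:
# the SGD update (lr / S) * (S * dL/dw) equals lr * dL/dw, so scaling the loss
# and descaling the learning rate cancel out in the applied step.
loss_scale = 1024.0
lr = 0.01
grad = 3.0e-7  # a gradient small enough to underflow in fp16 without scaling

update_plain = lr * grad
update_scaled = (lr / loss_scale) * (grad * loss_scale)
assert abs(update_plain - update_scaled) < 1e-18
print(update_plain, update_scaled)
```

Because the descaling is folded into the optimizer's learning rate rather than into the gradients, the summed-then-applied gradients never need an explicit unscale pass.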
diff --git a/examples/dlrm/model/gradient_descent_w.py b/examples/dlrm/model/gradient_descent_w.py
index f3ae78d77355237614cab90d6b25a36514bf8728..53adb996bb20424fc91a3722dee7270a45465ddc 100644
--- a/examples/dlrm/model/gradient_descent_w.py
+++ b/examples/dlrm/model/gradient_descent_w.py
@@ -47,14 +47,8 @@ class CustomizedGradientDescentWithWeighDecay(gradient_descent.GradientDescentOptimizer):
         super(CustomizedGradientDescentWithWeighDecay, self).__init__(
             learning_rate=learning_rate, use_locking=use_locking, name=self.unique_name
         )
-
-    def initialize_slots(self, var, table_instance):
-        logger.info("no slot for gradient descent")
-        return []
-
-    def insert_slot(self, slot, named_slots_key, slot_name):
-        logger.info("no slot for gradient descent")
-        return dict()
+        self._slot_num = 0
+        self._derivative = 1

     def get_slot_init_values(self):
         logger.info("no slot for gradient descent")
diff --git a/examples/dlrm/model/main_mxrec.py b/examples/dlrm/model/main_mxrec.py
index dd3e8d2d7d31e8dde2494890c963de5accd525f7..51ed7c4acd66f102a0fa0c782da735da8c910497 100644
--- a/examples/dlrm/model/main_mxrec.py
+++ b/examples/dlrm/model/main_mxrec.py
@@ -15,6 +15,7 @@
 # ==============================================================================

 import os
+import shutil
 import time
 import warnings
 import random
@@ -24,6 +25,10 @@
 import tensorflow as tf
 from sklearn.metrics import roc_auc_score
 import numpy as np

+from optimizer import get_dense_and_sparse_optimizer
+from config import sess_config, Config, SSD_DATA_PATH, CacheModeEnum
+from model import MyModel
+from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET
 from mx_rec.core.asc.helper import FeatureSpec, get_asc_insert_func
 from mx_rec.core.asc.manager import start_asc_pipeline
 from mx_rec.core.embedding import create_table, sparse_lookup
@@ -37,10 +42,6 @@
 from mx_rec.util.variable import get_dense_and_sparse_variable
 from mx_rec.util.log import logger
 from npu_bridge.npu_init import *

-from model import MyModel
-from config import sess_config, Config
-from optimizer import get_dense_and_sparse_optimizer
-
 npu_plugin.set_device_sat_mode(0)

 dense_hashtable_seed = 128
@@ -56,8 +57,8 @@ def add_timestamp_func(batch):
     return batch


-def make_batch_and_iterator(cfg, feature_spec_list, is_training, dump_graph, use_faae=False):
-    if cfg.USE_PIPELINE_TEST:
+def make_batch_and_iterator(config, feature_spec_list, is_training, dump_graph, is_use_faae=False):
+    if config.USE_PIPELINE_TEST:
         num_parallel = 1
     else:
         num_parallel = 8

     def extract_fn(data_record):
         features = {
             # Extract features using the keys set during creation
-            'label': tf.compat.v1.FixedLenFeature(shape=(cfg.line_per_sample,), dtype=tf.int64),
-            'sparse_feature': tf.compat.v1.FixedLenFeature(shape=(26 * cfg.line_per_sample,), dtype=tf.int64),
-            'dense_feature': tf.compat.v1.FixedLenFeature(shape=(13 * cfg.line_per_sample,), dtype=tf.float32),
+            'label': tf.compat.v1.FixedLenFeature(shape=(config.line_per_sample,), dtype=tf.int64),
+            'sparse_feature': tf.compat.v1.FixedLenFeature(shape=(26 * config.line_per_sample,), dtype=tf.int64),
+            'dense_feature': tf.compat.v1.FixedLenFeature(shape=(13 * config.line_per_sample,), dtype=tf.float32),
         }
         sample = tf.compat.v1.parse_single_example(data_record, features)
         return sample
@@ -80,24 +81,23 @@
         return batch

     if is_training:
-        files_list = glob(os.path.join(cfg.data_path, cfg.train_file_pattern) + '/*.tfrecord')
+        files_list = glob(os.path.join(config.data_path, config.train_file_pattern) + '/*.tfrecord')
     else:
-        files_list = glob(os.path.join(cfg.data_path, cfg.test_file_pattern) + '/*.tfrecord')
+        files_list = glob(os.path.join(config.data_path, config.test_file_pattern) + '/*.tfrecord')
     dataset = tf.data.TFRecordDataset(files_list, num_parallel_reads=num_parallel)
-    batch_size = cfg.batch_size // cfg.line_per_sample
+    batch_size = config.batch_size // config.line_per_sample

-    dataset = dataset.shard(cfg.rank_size, cfg.rank_id)
+    dataset = dataset.shard(config.rank_size, config.rank_id)
     if is_training:
         dataset = dataset.shuffle(batch_size * 1000, seed=shuffle_seed)
     if is_training:
-        dataset = dataset.repeat(cfg.train_epoch)
+        dataset = dataset.repeat(config.train_epoch)
     else:
-        dataset = dataset.repeat(cfg.test_epoch)
-    # dataset = dataset.repeat(cfg.num_epochs)
+        dataset = dataset.repeat(config.test_epoch)
     dataset = dataset.map(extract_fn, num_parallel_calls=num_parallel).batch(batch_size, drop_remainder=True)
     dataset = dataset.map(reshape_fn, num_parallel_calls=num_parallel)
-    if use_faae:
+    if is_use_faae:
         dataset = dataset.map(add_timestamp_func)

     if not MODIFY_GRAPH_FLAG:
@@ -128,7 +128,7 @@ def model_forward(feature_list, hash_table_list, batch, is_train, modify_graph):
     elif len(embedding_list) > 1:
         emb = tf.reduce_sum(embedding_list, axis=0, keepdims=False)
     else:
-        raise ValueError("The length of embedding_list must be greater than or equal to 1.")
+        raise ValueError("the length of embedding_list must be greater than or equal to 1.")
     my_model = MyModel()
     model_output = my_model.build_model(embedding=emb,
                                         dense_feature=batch["dense_feature"],
@@ -158,13 +158,13 @@ def evaluate():
         try:
             eval_current_steps += 1
             eval_start = time.time()
-            eval_loss, pred, label = sess.run([eval_model["loss"], eval_model["pred"], eval_label])
+            eval_loss, pred, label = sess.run([eval_model.get("loss"), eval_model.get("pred"), eval_label])
             eval_cost = time.time() - eval_start
-            qps = (1 / eval_cost) * rank_size * cfg.batch_size
+            qps_eval = (1 / eval_cost) * rank_size * cfg.batch_size
             log_loss_list += list(eval_loss.reshape(-1))
             pred_list += list(pred.reshape(-1))
             label_list += list(label.reshape(-1))
-            print(f"eval current_steps: {eval_current_steps}, qps: {qps}")
+            print(f"eval current_steps: {eval_current_steps}, qps: {qps_eval}")
             if eval_current_steps == eval_steps:
                 finished = True
         except tf.errors.OutOfRangeError:
@@ -189,7 +189,7 @@ def evaluate_fix(step):
     while not finished:
         try:
             eval_current_steps += 1
-            eval_loss, pred, label = sess.run([eval_model["loss"], eval_model["pred"], eval_model["label"]])
+            eval_loss, pred, label = sess.run([eval_model.get("loss"), eval_model.get("pred"), eval_model.get("label")])
             log_loss_list += list(eval_loss.reshape(-1))
             pred_list += list(pred.reshape(-1))
             label_list += list(label.reshape(-1))
@@ -216,7 +216,6 @@
     os.mknod(f"flag_{rank_id}.txt")
     while True:
         file_exists_list = [os.path.exists(f"flag_{i}.txt") for i in range(rank_size)]
-        # print(file_exists_list)
         if sum(file_exists_list) == rank_size:
             print("All saved!!!!!!!!!!")
             break
@@ -249,12 +248,38 @@ def create_feature_spec_list(use_timestamp=False):
     return feature_spec_list

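make_batch_and_iterator splits the Criteo TFRecords across workers before any shuffling: each rank keeps 1/rank_size of the files, shuffles with a fixed seed, repeats for the epoch count, then batches. An illustrative standalone version of that pipeline (plain tf.data, not the mxRec iterator plumbing; parameter names are ours):

```python
import tensorflow as tf

# Sketch of the per-rank input pipeline in make_batch_and_iterator: shard
# first so every worker sees a disjoint slice, then shuffle/repeat/batch.
def build_dataset(files_list, rank_size, rank_id, batch_size, epochs, seed=128):
    dataset = tf.data.TFRecordDataset(files_list, num_parallel_reads=8)
    dataset = dataset.shard(rank_size, rank_id)            # disjoint 1/rank_size slice
    dataset = dataset.shuffle(batch_size * 1000, seed=seed)  # fixed seed keeps runs comparable
    dataset = dataset.repeat(epochs)
    return dataset.batch(batch_size, drop_remainder=True)
```

Sharding before shuffle/repeat is what makes the fixed shuffle seed meaningful: each rank reshuffles only its own slice, so two runs with the same seed and rank layout consume identical batches.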
dir:{sub_dir}") + + +def _clear_saved_model() -> None: + _del_related_dir("/root/ascend/log/*") + _del_related_dir("kernel*") + _del_related_dir("model_dir_rank*") + _del_related_dir("op_cache") + + if os.getenv("CACHE_MODE", "") != CacheModeEnum.SSD.value: + return + logger.info("Current cache mode is SSD, and file overwrite is not allowed in SSD mode, deleting exist directory" + " then create empty directory for this use case.") + for sub_path in SSD_DATA_PATH: + _del_related_dir(sub_path) + os.makedirs(sub_path, mode=0o550, exist_ok=True) + logger.info(f"Create dir:{sub_path}") + + if __name__ == "__main__": tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) warnings.filterwarnings("ignore") + _clear_saved_model() rank_id = int(os.getenv("RANK_ID")) if os.getenv("RANK_ID") else None - rank_size = int(os.getenv("RANK_SIZE")) if os.getenv("RANK_SIZE") else None + rank_size = int(os.getenv("TRAIN_RANK_SIZE")) if os.getenv("TRAIN_RANK_SIZE") else None interval = int(os.getenv("INTERVAL")) if os.getenv("INTERVAL") else None train_steps = 10000 eval_steps = 1360 @@ -265,8 +290,8 @@ if __name__ == "__main__": MODIFY_GRAPH_FLAG = bool(int(os.getenv("USE_MODIFY_GRAPH", 0))) use_faae = bool(int(os.getenv("USE_FAAE", 0))) except ValueError as err: - raise ValueError(f"please correctly config USE_DYNAMIC_EXPANSION or USE_MULTI_LOOKUP or USE_FAAE " - f"or USE_MODIFY_GRAPH only 0 or 1 is supported.") from err + raise ValueError("please correctly config USE_DYNAMIC_EXPANSION or USE_MULTI_LOOKUP or USE_FAAE " + "or USE_MODIFY_GRAPH only 0 or 1 is supported.") from err use_dynamic = bool(int(os.getenv("USE_DYNAMIC", 0))) logger.info(f"USE_DYNAMIC:{use_dynamic}") @@ -290,16 +315,15 @@ if __name__ == "__main__": feature_spec_list_eval = create_feature_spec_list(use_timestamp=False) train_batch, train_iterator = make_batch_and_iterator(cfg, feature_spec_list_train, is_training=True, - dump_graph=True, use_faae=use_faae) + dump_graph=True, is_use_faae=use_faae) eval_batch, eval_iterator = make_batch_and_iterator(cfg, feature_spec_list_eval, is_training=False, - dump_graph=False, use_faae=use_faae) + dump_graph=False, is_use_faae=use_faae) logger.info(f"train_batch: {train_batch}") if use_faae: cfg.dev_vocab_size = cfg.dev_vocab_size // 2 optimizer_list = [get_dense_and_sparse_optimizer(cfg)] - sparse_optimizer_list = [sparse_optimizer for dense_optimizer, sparse_optimizer in optimizer_list] # note: variance_scaling_initializer only support HBM mode emb_initializer = tf.compat.v1.truncated_normal_initializer(stddev=0.05, seed=sparse_hashtable_seed) \ @@ -310,7 +334,6 @@ if __name__ == "__main__": dim=tf.TensorShape([cfg.emb_dim]), name="sparse_embeddings", emb_initializer=emb_initializer, - optimizer_list=[sparse_optimizer_list[0]._optimizer], **cfg.get_emb_table_cfg() ) if use_faae: @@ -323,15 +346,20 @@ if __name__ == "__main__": is_train=False, modify_graph=MODIFY_GRAPH_FLAG) dense_variables, sparse_variables = get_dense_and_sparse_variable() - + trainable_varibles = [] + trainable_varibles.extend(dense_variables) + if use_dynamic_expansion: + trainable_varibles.append(tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB)[0]) + else: + trainable_varibles.extend(sparse_variables) rank_size = mxrec_util.communication.hccl_ops.get_rank_size() train_ops = [] # multi task training - for loss, (dense_optimizer, sparse_optimizer) in zip([train_model["loss"]], optimizer_list): + for loss, (dense_optimizer, sparse_optimizer) in zip([train_model.get("loss")], optimizer_list): # do dense 
optimization - grads = dense_optimizer.compute_gradients(loss, var_list=dense_variables) + grads = dense_optimizer.compute_gradients(loss, var_list=trainable_varibles) avg_grads = [] - for grad, var in grads: + for grad, var in grads[:-1]: if rank_size > 1: grad = hccl_ops.allreduce(grad, "sum") if grad is not None else None if grad is not None: @@ -340,17 +368,14 @@ if __name__ == "__main__": train_ops.append(dense_optimizer.apply_gradients(avg_grads)) if use_dynamic_expansion: - from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS - - train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS) - train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB) + train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET) # do sparse optimization by addr - sparse_grads = sparse_optimizer.compute_gradients(loss, train_emb_list) # local_embedding + sparse_grads = list(grads[-1]) # local_embedding grads_and_vars = [(grad, address) for grad, address in zip(sparse_grads, train_address_list)] train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars)) else: # do sparse optimization - sparse_grads = sparse_optimizer.compute_gradients(loss, sparse_variables) + sparse_grads = list(grads[-1]) print("sparse_grads_tensor:", sparse_grads) grads_and_vars = [(grad, variable) for grad, variable in zip(sparse_grads, sparse_variables)] train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars)) @@ -410,7 +435,7 @@ if __name__ == "__main__": start_time = time.time() try: - grad, loss = sess.run([train_ops, train_model["loss"]]) + grad, loss = sess.run([train_ops, train_model.get("loss")]) lr = sess.run(cfg.learning_rate) global_step = sess.run(cfg.global_step) except tf.errors.OutOfRangeError: @@ -421,7 +446,6 @@ if __name__ == "__main__": cost_time = end_time - start_time qps = (1 / cost_time) * rank_size * cfg.batch_size * iteration_per_loop cost_sum += cost_time - # qps_sum += qps logger.info(f"step: {i * iteration_per_loop}; training loss: {loss}") logger.info(f"step: {i * iteration_per_loop}; grad: {grad}") logger.info(f"step: {i * iteration_per_loop}; lr: {lr}") diff --git a/examples/dlrm/model/mean_auc.py b/examples/dlrm/model/mean_auc.py index 1116ebd578d691b63e391e2d53e144801e80fe11..ff57df00e575551456883147a5772a3b16dc638f 100644 --- a/examples/dlrm/model/mean_auc.py +++ b/examples/dlrm/model/mean_auc.py @@ -15,8 +15,8 @@ # ============================================================================== import os -import numpy as np from glob import glob +import numpy as np def split_auc(log_input): @@ -26,7 +26,7 @@ def split_auc(log_input): if 'Test' in line: all_auc.append(float(line.split(';')[0].split(':')[-1].strip())) all_auc_len = len(all_auc) - all_auc_arr = np.array(all_auc)[:all_auc_len - all_auc_len%8] + all_auc_arr = np.array(all_auc)[:all_auc_len - all_auc_len % 8] test_auc = np.mean(all_auc_arr.reshape(-1, 8), axis=-1) return test_auc diff --git a/examples/dlrm/model/optimizer.py b/examples/dlrm/model/optimizer.py index 7a6d687834ca4f8008353d176635d9f0df4db8b3..18dbe28887d062de143f52a5b70da6c861852a1a 100644 --- a/examples/dlrm/model/optimizer.py +++ b/examples/dlrm/model/optimizer.py @@ -15,20 +15,32 @@ # ============================================================================== import tensorflow as tf + from delay_loss_scale import DenseLossScaleOptimizer, SparseLossScaleOptimizer from gradient_descent_w import create_hash_optimizer from mx_rec.util.initialize 
import ConfigInitializer from mx_rec.optimizers.gradient_descent_by_addr import create_hash_optimizer_by_addr +from mx_rec.optimizers import lazy_adam def get_dense_and_sparse_optimizer(cfg): - dense_optimizer = tf.train.GradientDescentOptimizer(learning_rate=cfg.learning_rate[0]) use_dynamic_expansion = ConfigInitializer.get_instance().use_dynamic_expansion - if use_dynamic_expansion: - sparse_optimizer = create_hash_optimizer_by_addr(learning_rate=cfg.learning_rate[1], weight_decay=0.0001) + if cfg.use_lazy_adam_optimizer: + if use_dynamic_expansion: + raise RuntimeError("model is incompatible with dynamic_expansion when use lazy_adam optimizer.") + # use lazy_adam optimizer + dense_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=cfg.learning_rate[0]) + sparse_optimizer = lazy_adam.create_hash_optimizer(learning_rate=cfg.learning_rate[1]) + loss_scale = 65536 else: - sparse_optimizer = create_hash_optimizer(learning_rate=cfg.learning_rate[1], weight_decay=0.0001) - sparse_optimizer = SparseLossScaleOptimizer(sparse_optimizer, 1024) - dense_optimizer = DenseLossScaleOptimizer(dense_optimizer, 1024) + # use SGD optimizer + dense_optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=cfg.learning_rate[0]) + if use_dynamic_expansion: + sparse_optimizer = create_hash_optimizer_by_addr(learning_rate=cfg.learning_rate[1], weight_decay=0.0001) + else: + sparse_optimizer = create_hash_optimizer(learning_rate=cfg.learning_rate[1], weight_decay=0.0001) + loss_scale = 1024 + sparse_optimizer = SparseLossScaleOptimizer(sparse_optimizer, loss_scale) + dense_optimizer = DenseLossScaleOptimizer(dense_optimizer, loss_scale) return dense_optimizer, sparse_optimizer diff --git a/examples/dlrm/model/run.sh b/examples/dlrm/model/run.sh index 919f0f98a09921c9b866c3f59723c84408be6224..6c1424435bac57a727cfc1980e25e08c6c5ba74c 100644 --- a/examples/dlrm/model/run.sh +++ b/examples/dlrm/model/run.sh @@ -20,10 +20,13 @@ so_path=$1 mx_rec_package_path=$2 hccl_cfg_json=$3 dlrm_criteo_data_path=$4 +ip=$5 # no ranktable时传入该参数 -export RANK_SIZE=8 -echo "RANK_SIZE=${RANK_SIZE}, please make sure hccl configuration json file match this parameter" -export RANK_TABLE_FILE=${hccl_cfg_json} +interface="lo" +num_server=1 +local_rank_size=8 +num_process=$((num_server * local_rank_size)) +export TRAIN_RANK_SIZE=$num_process ################# 参数配置 ###################### export USE_DYNAMIC=0 # 0:静态shape;1:动态shape @@ -33,26 +36,13 @@ export USE_DYNAMIC_EXPANSION=0 # 0:关闭动态扩容;1: 开启动态扩容 export USE_MULTI_LOOKUP=0 # 0:一表一查;1:一表多查 export USE_MODIFY_GRAPH=0 # 0:feature spec模式;1:自动改图模式 ################################################ - echo "CACHE_MODE:${CACHE_MODE}" -if [ ${CACHE_MODE} = "SSD" ]; then - echo "SSD train mode not allow file exist before training, - deleting dir ${cur_path}/ssd_data then create for SSD use case" - rm -rf ssd_data - mkdir ssd_data -fi export HCCL_CONNECT_TIMEOUT=1200 - export DLRM_CRITEO_DATA_PATH=${dlrm_criteo_data_path} export PYTHONPATH=${mx_rec_package_path}:${so_path}:$PYTHONPATH export LD_PRELOAD=/usr/lib64/libgomp.so.1 export LD_LIBRARY_PATH=${so_path}:/usr/local/lib:$LD_LIBRARY_PATH - -rm -rf kernel* -rm -rf /root/ascend/log/* -rm -rf model_dir_rank* op_cache - export ASCEND_DEVICE_ID=0 export RANK_ID_START=0 export JOB_ID=10086 @@ -75,37 +65,35 @@ RANK_ID_START=0 export MXREC_MODE="ASC" echo "MXREC_MODE is $MXREC_MODE" -export USE_MPI=1 -echo "USE_MPI is $USE_MPI" export py=main_mxrec.py echo "py is $py" - -if [ $USE_MPI -eq 0 ]; then - echo "use for loop to start tasks" - 
for((RANK_ID=$RANK_ID_START;RANK_ID<$((RANK_SIZE+RANK_ID_START));RANK_ID++)); - do - #设置环境变量,不需要修改 - echo "Device ID: $RANK_ID" - export RANK_ID=$RANK_ID - export ASCEND_DEVICE_ID=$RANK_ID - ASCEND_DEVICE_ID=$RANK_ID - if [ -d $cur_path/output/${ASCEND_DEVICE_ID} ];then - rm -rf $cur_path/output/${ASCEND_DEVICE_ID} - mkdir -p $cur_path/output/${ASCEND_DEVICE_ID} - else - mkdir -p $cur_path/output/${ASCEND_DEVICE_ID} - fi - nohup python3 ${py} > $cur_path/output/$ASCEND_DEVICE_ID/test_$ASCEND_DEVICE_ID.log 2>&1 & - done +# 区分ranktable和no ranktable +if [ -n "$ip" ]; then + # no ranktable分支 + echo "Current is no ranktable solution." + echo "Input node ip: $ip, please make sure this ip is available." + export CM_CHIEF_IP=$ip # 主节点ip + export CM_CHIEF_PORT=60001 # 主节点监听端口 + export CM_CHIEF_DEVICE=0 # 主节点device id + export CM_WORKER_IP=$ip # 当前节点ip + export CM_WORKER_SIZE=$num_process # 参与集群训练的device数量 + echo "CM_CHIEF_IP=$CM_CHIEF_IP" + echo "CM_CHIEF_PORT=$CM_CHIEF_PORT" + echo "CM_CHIEF_DEVICE=$CM_CHIEF_DEVICE" + echo "CM_WORKER_IP=$CM_WORKER_IP" + echo "CM_WORKER_SIZE=$CM_WORKER_SIZE" else - echo "use horovod to start tasks" - # GLOG_stderrthreshold -2:TRACE -1:DEBUG 0:INFO 1:WARN 2.ERROR, 默认为INFO - mpi_args='-x BIND_INFO="0:12 12:48 60:48" -x GLOG_stderrthreshold=2 -x GLOG_logtostderr=true -bind-to none -x NCCL_SOCKET_IFNAME=docker0 -mca btl_tcp_if_exclude docker0' - interface="lo" - - horovodrun --network-interface ${interface} -np ${RANK_SIZE} --mpi-args "${mpi_args}" --mpi -H localhost:${RANK_SIZE} \ - python3.7 ${py} 2>&1 | tee temp_${CACHE_MODE}_${RANK_SIZE}p.log + # ranktable分支 + echo "Current is ranktable solution, hccl json file:${hccl_cfg_json}" + export RANK_SIZE=$num_process + echo "RANK_SIZE=${RANK_SIZE}, please make sure hccl configuration json file match this parameter" + export RANK_TABLE_FILE=${hccl_cfg_json} fi +echo "use horovod to start tasks" +# GLOG_stderrthreshold -2:TRACE -1:DEBUG 0:INFO 1:WARN 2.ERROR, 默认为INFO +mpi_args='-x BIND_INFO="0:12 12:48 60:48" -x GLOG_stderrthreshold=2 -x GLOG_logtostderr=true -bind-to none -x NCCL_SOCKET_IFNAME=docker0 -mca btl_tcp_if_exclude docker0' +horovodrun --network-interface ${interface} -np ${num_process} --mpi-args "${mpi_args}" --mpi -H localhost:${local_rank_size} \ +python3.7 ${py} 2>&1 | tee temp_${CACHE_MODE}_${num_process}p.log diff --git a/examples/mmoe/config.py b/examples/mmoe/config.py new file mode 100644 index 0000000000000000000000000000000000000000..b6a83582ede79bf7fcd2558e4005226d389d2bbe --- /dev/null +++ b/examples/mmoe/config.py @@ -0,0 +1,192 @@ +# coding=utf-8 +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+
+import os
+
+import tensorflow as tf
+from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig
+from npu_bridge.estimator.npu.npu_config import NPURunConfig
+
+from mx_rec.constants.constants import CacheModeEnum
+
+SSD_DATA_PATH = ["ssd_data"]
+
+
+class LearningRateScheduler:
+    """
+    LR scheduler for the dense and sparse optimizers.
+    Currently applies a constant factor of 1.0; intended as the hook point for
+    Polynomial Decay with Warmup, where TF-based cond operations would be
+    necessary for performance in graph mode.
+    """
+
+    def __init__(self, base_lr_dense, base_lr_sparse):
+        self.base_lr_dense = base_lr_dense
+        self.base_lr_sparse = base_lr_sparse
+
+    def calc(self):
+        # used for the constant stage
+        lr_factor_constant = tf.cast(1.0, tf.float32)
+
+        lr_sparse = self.base_lr_sparse * lr_factor_constant
+        lr_dense = self.base_lr_dense * lr_factor_constant
+        return lr_dense, lr_sparse
+
+
+class Config:
+    def __init__(self) -> None:
+        self.rank_id = int(os.getenv("OMPI_COMM_WORLD_RANK")) if os.getenv("OMPI_COMM_WORLD_RANK") else None
+        tmp = os.getenv("TRAIN_RANK_SIZE")
+        if tmp is None:
+            raise ValueError("please export TRAIN_RANK_SIZE")
+        self.rank_size = int(tmp)
+
+        self.data_path = os.getenv("DLRM_CRITEO_DATA_PATH")
+        self.train_file_pattern = "train"
+        self.test_file_pattern = "test"
+
+        self.batch_size = 32
+        self.line_per_sample = 1
+        self.train_epoch = 100
+        self.test_epoch = 100
+        self.expert_num = 8
+        self.gate_num = 2
+        self.expert_size = 16
+        self.tower_size = 8
+
+        self.perform_shuffle = False
+
+        self.key_type = tf.int64
+        self.label_type = tf.float32
+        self.value_type = tf.int64
+
+        self.feat_cnt = 26
+        self.__set_emb_table_size()
+
+        self.field_num = 26
+        self.send_count = self.get_send_count(self.rank_size)
+
+        self.emb_dim = self.expert_num * self.expert_size + self.gate_num * self.expert_num
+        self.hashtable_threshold = 1
+
+        self.USE_PIPELINE_TEST = False
+
+        self.global_step = tf.Variable(0, trainable=False)
+        _lr_scheduler = LearningRateScheduler(
+            0.001,
+            0.001
+        )
+        self.learning_rate = _lr_scheduler.calc()
+
+    @staticmethod
+    def get_send_count(rank_size):
+        try:
+            return 46000 // rank_size
+        except ZeroDivisionError as exp:
+            raise ZeroDivisionError('Rank size can not be zero.') from exp
+
+    def __set_emb_table_size(self) -> None:
+        self.cache_mode = os.getenv("CACHE_MODE")
+        if self.cache_mode is None:
+            raise ValueError("please export CACHE_MODE environment variable, support:[HBM, DDR, SSD]")
+
+        if self.cache_mode == CacheModeEnum.HBM.value:
+            self.dev_vocab_size = 1000 * self.rank_size
+            self.host_vocab_size = 0
+        elif self.cache_mode == CacheModeEnum.DDR.value:
+            self.dev_vocab_size = 1000 * self.rank_size
+            self.host_vocab_size = 1000 * self.rank_size
+        elif self.cache_mode == CacheModeEnum.SSD.value:
+            self.dev_vocab_size = 1000 * self.rank_size
+            self.host_vocab_size = 1000 * self.rank_size
+            self.ssd_vocab_size = 1000 * self.rank_size
+        else:
+            raise ValueError(f"get CACHE_MODE:{self.cache_mode}, expect in [HBM, DDR, SSD]")
+
+    def get_emb_table_cfg(self) -> dict:
+        if self.cache_mode == CacheModeEnum.HBM.value:
+            return {"device_vocabulary_size": self.dev_vocab_size}
+        elif self.cache_mode == CacheModeEnum.DDR.value:
+            return {"device_vocabulary_size": self.dev_vocab_size,
+                    "host_vocabulary_size": self.host_vocab_size}
+        elif self.cache_mode == CacheModeEnum.SSD.value:
+            return {"device_vocabulary_size": self.dev_vocab_size,
+                    "host_vocabulary_size": self.host_vocab_size,
+                    "ssd_vocabulary_size": self.ssd_vocab_size,
+
"ssd_data_path": SSD_DATA_PATH} + else: + raise RuntimeError(f"get CACHE_MODE:{self.cache_mode}, check Config.__set_emb_table_size implementation") + + +def sess_config(dump_data=False, dump_path="./dump_output", dump_steps="0|1|2"): + session_config = tf.ConfigProto(allow_soft_placement=False, + log_device_placement=False) + session_config.gpu_options.allow_growth = True + custom_op = session_config.graph_options.rewrite_options.custom_optimizers.add() + custom_op.name = "NpuOptimizer" + custom_op.parameter_map["mix_compile_mode"].b = False + custom_op.parameter_map["use_off_line"].b = True + custom_op.parameter_map["min_group_size"].b = 1 + # 可选配置level0:pairwise;level1:pairwise + custom_op.parameter_map["HCCL_algorithm"].s = tf.compat.as_bytes("level0:fullmesh;level1:fullmesh") + custom_op.parameter_map["enable_data_pre_proc"].b = True + custom_op.parameter_map["iterations_per_loop"].i = 10 + custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision") + custom_op.parameter_map["hcom_parallel"].b = False + custom_op.parameter_map["op_precision_mode"].s = tf.compat.as_bytes("op_impl_mode.ini") + custom_op.parameter_map["op_execute_timeout"].i = 2000 + custom_op.parameter_map["variable_memory_max_size"].s = tf.compat.as_bytes( + str(13 * 1024 * 1024 * 1024)) # total 31 need 13; + custom_op.parameter_map["graph_memory_max_size"].s = tf.compat.as_bytes(str(18 * 1024 * 1024 * 1024)) # need 25 + custom_op.parameter_map["stream_max_parallel_num"].s = tf.compat.as_bytes("DNN_VM_AICPU:3,AIcoreEngine:3") + + if dump_data: + custom_op.parameter_map["enable_dump"].b = True + custom_op.parameter_map["dump_path"].s = tf.compat.as_bytes(dump_path) + custom_op.parameter_map["dump_step"].s = tf.compat.as_bytes(dump_steps) + custom_op.parameter_map["dump_mode"].s = tf.compat.as_bytes("all") + + session_config.graph_options.rewrite_options.remapping = RewriterConfig.OFF + session_config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF + + return session_config + + +def get_npu_run_config(): + session_config = tf.ConfigProto(allow_soft_placement=False, + log_device_placement=False) + + session_config.gpu_options.allow_growth = True + custom_op = session_config.graph_options.rewrite_options.custom_optimizers.add() + custom_op.name = "NpuOptimizer" + session_config.graph_options.rewrite_options.remapping = RewriterConfig.OFF + session_config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF + + run_config = NPURunConfig( + save_summary_steps=1000, + save_checkpoints_steps=100, + keep_checkpoint_max=5, + session_config=session_config, + log_step_count_steps=20, + precision_mode='allow_mix_precision', + enable_data_pre_proc=True, + iterations_per_loop=1, + jit_compile=False, + op_compiler_cache_mode="enable", + HCCL_algorithm="level0:fullmesh;level1:fullmesh" # 可选配置:level0:pairwise;level1:pairwise + ) + return run_config diff --git a/examples/mmoe/main_mxrec.py b/examples/mmoe/main_mxrec.py new file mode 100644 index 0000000000000000000000000000000000000000..d02566aa9339016181cc42ab0672bbddbfb8e9e4 --- /dev/null +++ b/examples/mmoe/main_mxrec.py @@ -0,0 +1,469 @@ +# coding=utf-8 +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import os +import shutil +import time +import warnings +import random +from glob import glob + +import tensorflow as tf +from sklearn.metrics import roc_auc_score +import numpy as np +from npu_bridge.npu_init import * +from config import sess_config, Config, SSD_DATA_PATH, CacheModeEnum +from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET +from mx_rec.core.asc.helper import FeatureSpec, get_asc_insert_func +from mx_rec.core.asc.manager import start_asc_pipeline +from mx_rec.core.embedding import create_table, sparse_lookup +from mx_rec.core.feature_process import EvictHook +from mx_rec.graph.modifier import modify_graph_and_start_emb_cache, GraphModifierHook +from mx_rec.constants.constants import ASCEND_TIMESTAMP +from mx_rec.util.initialize import ConfigInitializer, init, terminate_config_initializer +from mx_rec.util.ops import import_host_pipeline_ops +import mx_rec.util as mxrec_util +from mx_rec.util.variable import get_dense_and_sparse_variable +from mx_rec.util.log import logger +from optimizer import get_dense_and_sparse_optimizer + +from model import MyModel + +npu_plugin.set_device_sat_mode(0) + +dense_hashtable_seed = 128 +sparse_hashtable_seed = 128 +shuffle_seed = 128 +random.seed(shuffle_seed) + + +def add_timestamp_func(batch): + timestamp = import_host_pipeline_ops().return_timestamp(tf.cast(batch['label'], dtype=tf.int64)) + batch["timestamp"] = timestamp + return batch + + +def make_batch_and_iterator(config, feature_spec_list, is_training, dump_graph, is_use_faae=False): + if config.USE_PIPELINE_TEST: + num_parallel = 1 + else: + num_parallel = 8 + + def extract_fn(data_record): + features = { + # Extract features using the keys set during creation + 'label': tf.compat.v1.FixedLenFeature(shape=(2 * config.line_per_sample,), dtype=tf.int64), + 'sparse_feature': tf.compat.v1.FixedLenFeature(shape=(29 * config.line_per_sample,), dtype=tf.int64), + 'dense_feature': tf.compat.v1.FixedLenFeature(shape=(11 * config.line_per_sample,), dtype=tf.float32), + } + sample = tf.compat.v1.parse_single_example(data_record, features) + return sample + + def reshape_fn(batch): + batch['label'] = tf.reshape(batch['label'], [-1, 2]) + batch['dense_feature'] = tf.reshape(batch['dense_feature'], [-1, 11]) + batch['sparse_feature'] = tf.reshape(batch['sparse_feature'], [-1, 29]) + return batch + + if is_training: + files_list = glob(os.path.join(config.data_path, config.train_file_pattern) + '/*.tfrecord') + else: + files_list = glob(os.path.join(config.data_path, config.test_file_pattern) + '/*.tfrecord') + dataset = tf.data.TFRecordDataset(files_list, num_parallel_reads=num_parallel) + batch_size = config.batch_size // config.line_per_sample + + dataset = dataset.shard(config.rank_size, config.rank_id) + if is_training: + dataset = dataset.shuffle(batch_size * 1000, seed=shuffle_seed) + if is_training: + dataset = dataset.repeat(config.train_epoch) + else: + dataset = dataset.repeat(config.test_epoch) + dataset = dataset.map(extract_fn, 
num_parallel_calls=num_parallel).batch(batch_size, + drop_remainder=True) + dataset = dataset.map(reshape_fn, num_parallel_calls=num_parallel) + if is_use_faae: + dataset = dataset.map(add_timestamp_func) + + if not MODIFY_GRAPH_FLAG: + insert_fn = get_asc_insert_func(tgt_key_specs=feature_spec_list, is_training=is_training, dump_graph=dump_graph) + dataset = dataset.map(insert_fn) + + dataset = dataset.prefetch(100) + + iterator = dataset.make_initializable_iterator() + batch = iterator.get_next() + return batch, iterator + + +def model_forward(feature_list, hash_table_list, batch, is_train, modify_graph): + embedding_list = [] + logger.debug(f"In model_forward function, is_train: {is_train}, feature_list: {len(feature_list)}, " + f"hash_table_list: {len(hash_table_list)}") + for feature, hash_table in zip(feature_list, hash_table_list): + if MODIFY_GRAPH_FLAG: + feature = batch["sparse_feature"] + embedding = sparse_lookup(hash_table, feature, cfg.send_count, dim=None, is_train=is_train, + name="user_embedding_lookup", modify_graph=modify_graph, batch=batch, + access_and_evict_config=None) + embedding_list.append(embedding) + + if len(embedding_list) == 1: + emb = embedding_list[0] + elif len(embedding_list) > 1: + emb = tf.reduce_sum(embedding_list, axis=0, keepdims=False) + else: + raise ValueError("the length of embedding_list must be greater than or equal to 1.") + emb = tf.reduce_sum(emb, axis=1) + my_model = MyModel() + model_output = my_model.build_model(embedding=emb, + dense_feature=batch["dense_feature"], + label=batch["label"], + is_training=is_train, + seed=dense_hashtable_seed) + return model_output + + +def evaluate(): + print("read_test dataset") + if not MODIFY_GRAPH_FLAG: + eval_label = eval_model.get("label") + sess.run([eval_iterator.initializer]) + else: + # In sess run mode, if the label from the original batch is still used for sess run, + # a getnext timeout error will occur, and a new batch from the new dataset needs to be used + eval_label = ConfigInitializer.get_instance().train_params_config.get_target_batch(False).get("label") + sess.run([ConfigInitializer.get_instance().train_params_config.get_initializer(False)]) + log_loss_list = [] + pred_income_list = [] + pred_mat_list = [] + label_income_list = [] + label_mat_list = [] + eval_current_steps = 0 + finished = False + print("eval begin") + + while not finished: + + eval_current_steps += 1 + eval_start = time.time() + try: + eval_loss, pred, label = sess.run([eval_model.get("loss"), eval_model.get("pred"), eval_label]) + except tf.errors.OutOfRangeError: + break + eval_cost = time.time() - eval_start + qps_eval = (1 / eval_cost) * rank_size * cfg.batch_size + log_loss_list += list(eval_loss.reshape(-1)) + pred_income = pred[0] + pred_mat = pred[1] + pred_income_list += list(pred_income.reshape(-1)) + pred_mat_list += list(pred_mat.reshape(-1)) + label_income_list += list(label[:, 0].reshape(-1)) + label_mat_list += list(label[:, 1].reshape(-1)) + print(f"eval current_steps: {eval_current_steps}, qps: {qps_eval}") + if eval_current_steps == eval_steps: + finished = True + + auc_income = roc_auc_score(label_income_list, pred_income_list) + auc_mat = roc_auc_score(label_mat_list, pred_mat_list) + mean_log_loss = np.mean(log_loss_list) + return auc_income, auc_mat, mean_log_loss + + +def evaluate_fix(step): + print("read_test dataset evaluate_fix") + if not MODIFY_GRAPH_FLAG: + sess.run([eval_iterator.initializer]) + else: + sess.run([ConfigInitializer.get_instance().train_params_config.get_initializer(False)]) 
+ log_loss_list = [] + pred_list = [] + label_list = [] + eval_current_steps = 0 + finished = False + print("eval begin") + while not finished: + try: + eval_current_steps += 1 + eval_loss, pred, label = sess.run([eval_model.get("loss"), eval_model.get("pred"), eval_model.get("label")]) + log_loss_list += list(eval_loss.reshape(-1)) + pred_list += list(pred.reshape(-1)) + label_list += list(label.reshape(-1)) + print(f"eval current_steps: {eval_current_steps}") + + if eval_current_steps == eval_steps: + finished = True + except tf.errors.OutOfRangeError: + finished = True + + label_numpy = np.array(label_list) + pred_numpy = np.array(pred_list) + if not os.path.exists(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}"): + os.makedirs(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}") + + if os.path.exists(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/label_{rank_id}.npy"): + os.remove(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/label_{rank_id}.npy") + if os.path.exists(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/pred_{rank_id}.npy"): + os.remove(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/pred_{rank_id}.npy") + if os.path.exists(f"flag_{rank_id}.txt"): + os.remove(f"flag_{rank_id}.txt") + np.save(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/label_{rank_id}.npy", label_numpy) + np.save(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/pred_{rank_id}.npy", pred_numpy) + os.mknod(f"flag_{rank_id}.txt") + while True: + file_exists_list = [os.path.exists(f"flag_{i}.txt") for i in range(rank_size)] + if sum(file_exists_list) == rank_size: + print("All saved!!!!!!!!!!") + break + else: + print("Waitting for saving numpy!!!!!!!!") + time.sleep(1) + continue + + auc = roc_auc_score(label_list, pred_list) + mean_log_loss = np.mean(log_loss_list) + return auc, mean_log_loss + + +def create_feature_spec_list(use_timestamp=False): + access_threshold = None + eviction_threshold = None + if use_timestamp: + access_threshold = 1000 + eviction_threshold = 180 + + feature_spec_list = [FeatureSpec("sparse_feature", table_name="sparse_embeddings", batch_size=cfg.batch_size, + access_threshold=access_threshold, eviction_threshold=eviction_threshold)] + if use_multi_lookup: + feature_spec_list.append(FeatureSpec("sparse_feature", table_name="sparse_embeddings", + batch_size=cfg.batch_size, + access_threshold=access_threshold, + eviction_threshold=eviction_threshold)) + if use_timestamp: + feature_spec_list.append(FeatureSpec("timestamp", is_timestamp=True)) + return feature_spec_list + + +def _del_related_dir(del_path: str) -> None: + if not os.path.isabs(del_path): + del_path = os.path.join(os.getcwd(), del_path) + dirs = glob(del_path) + for sub_dir in dirs: + shutil.rmtree(sub_dir, ignore_errors=True) + logger.info(f"Delete dir:{sub_dir}") + + +def _clear_saved_model() -> None: + _del_related_dir("/root/ascend/log/*") + _del_related_dir("kernel*") + _del_related_dir("model_dir_rank*") + _del_related_dir("op_cache") + + if os.getenv("CACHE_MODE", "") != CacheModeEnum.SSD.value: + return + logger.info("Current cache mode is SSD, and file overwrite is not allowed in SSD mode, deleting exist directory" + " then create empty directory for this use case.") + for sub_path in SSD_DATA_PATH: + _del_related_dir(sub_path) + os.makedirs(sub_path, mode=0o550, exist_ok=True) + logger.info(f"Create dir:{sub_path}") + + +if __name__ == "__main__": + tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) + 
warnings.filterwarnings("ignore") + _clear_saved_model() + + rank_size = int(os.getenv("TRAIN_RANK_SIZE")) if os.getenv("TRAIN_RANK_SIZE") else None + interval = int(os.getenv("INTERVAL")) if os.getenv("INTERVAL") else None + train_steps = 1000 + eval_steps = 1000 + + try: + use_dynamic_expansion = bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0))) + use_multi_lookup = bool(int(os.getenv("USE_MULTI_LOOKUP", 0))) + MODIFY_GRAPH_FLAG = bool(int(os.getenv("USE_MODIFY_GRAPH", 0))) + use_faae = bool(int(os.getenv("USE_FAAE", 0))) + except ValueError as err: + raise ValueError("please correctly config USE_DYNAMIC_EXPANSION or USE_MULTI_LOOKUP or USE_FAAE " + "or USE_MODIFY_GRAPH only 0 or 1 is supported.") from err + + use_dynamic = bool(int(os.getenv("USE_DYNAMIC", 0))) + logger.info(f"USE_DYNAMIC:{use_dynamic}") + init(train_steps=train_steps, eval_steps=eval_steps, + use_dynamic=use_dynamic, use_dynamic_expansion=use_dynamic_expansion) + + rank_id = mxrec_util.communication.hccl_ops.get_rank_id() + cfg = Config() + feature_spec_list_train = None + feature_spec_list_eval = None + if use_faae: + feature_spec_list_train = create_feature_spec_list(use_timestamp=True) + feature_spec_list_eval = create_feature_spec_list(use_timestamp=True) + else: + feature_spec_list_train = create_feature_spec_list(use_timestamp=False) + feature_spec_list_eval = create_feature_spec_list(use_timestamp=False) + + train_batch, train_iterator = make_batch_and_iterator(cfg, feature_spec_list_train, is_training=True, + dump_graph=True, is_use_faae=use_faae) + eval_batch, eval_iterator = make_batch_and_iterator(cfg, feature_spec_list_eval, is_training=False, + dump_graph=False, is_use_faae=use_faae) + logger.info(f"train_batch: {train_batch}") + + if use_faae: + cfg.dev_vocab_size = cfg.dev_vocab_size // 2 + + optimizer_list = [get_dense_and_sparse_optimizer(cfg)] + + # note: variance_scaling_initializer only support HBM mode + emb_initializer = tf.constant_initializer(value=0.1) + sparse_hashtable = create_table( + key_dtype=cfg.key_type, + dim=tf.TensorShape([cfg.emb_dim]), + name="sparse_embeddings", + emb_initializer=emb_initializer, + **cfg.get_emb_table_cfg() + ) + if use_faae: + tf.compat.v1.add_to_collection(ASCEND_TIMESTAMP, train_batch["timestamp"]) + + sparse_hashtable_list = [sparse_hashtable, sparse_hashtable] if use_multi_lookup else [sparse_hashtable] + train_model = model_forward(feature_spec_list_train, sparse_hashtable_list, train_batch, + is_train=True, modify_graph=MODIFY_GRAPH_FLAG) + eval_model = model_forward(feature_spec_list_eval, sparse_hashtable_list, eval_batch, + is_train=False, modify_graph=MODIFY_GRAPH_FLAG) + + dense_variables, sparse_variables = get_dense_and_sparse_variable() + trainable_varibles = [] + trainable_varibles.extend(dense_variables) + if use_dynamic_expansion: + trainable_varibles.append(tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB)[0]) + else: + trainable_varibles.extend(sparse_variables) + rank_size = mxrec_util.communication.hccl_ops.get_rank_size() + train_ops = [] + # multi task training + for loss, (dense_optimizer, sparse_optimizer) in zip([train_model.get("loss")], optimizer_list): + # do dense optimization + grads = dense_optimizer.compute_gradients(loss, var_list=trainable_varibles) + avg_grads = [] + for grad, var in grads[:-1]: + if rank_size > 1: + grad = hccl_ops.allreduce(grad, "sum") if grad is not None else None + if grad is not None: + avg_grads.append((grad / 8.0, var)) + # apply gradients: update variables + 
train_ops.append(dense_optimizer.apply_gradients(avg_grads)) + + if use_dynamic_expansion: + train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET) + # do sparse optimization by addr + sparse_grads = list(grads[-1]) # local_embedding + grads_and_vars = [(grad, address) for grad, address in zip(sparse_grads, train_address_list)] + train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars)) + else: + # do sparse optimization + sparse_grads = list(grads[-1]) + print("sparse_grads_tensor:", sparse_grads) + grads_and_vars = [(grad, variable) for grad, variable in zip(sparse_grads, sparse_variables)] + train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars)) + + + with tf.control_dependencies(train_ops): + train_ops = tf.no_op() + cfg.learning_rate = [cfg.learning_rate[0], cfg.learning_rate[1]] + + if MODIFY_GRAPH_FLAG: + modify_graph_and_start_emb_cache(dump_graph=True) + else: + start_asc_pipeline() + + hook_list = [] + if use_faae: + hook_evict = EvictHook(evict_enable=True, evict_time_interval=120) + hook_list.append(hook_evict) + if MODIFY_GRAPH_FLAG: # 该场景添加hook处理校验问题 + hook_list.append(GraphModifierHook(modify_graph=False)) + + if use_faae: + sess = tf.compat.v1.train.MonitoredTrainingSession( + hooks=hook_list, + config=sess_config(dump_data=False) + ) + sess.graph._unsafe_unfinalize() + if not MODIFY_GRAPH_FLAG: + sess.run(train_iterator.initializer) + else: + sess.run(ConfigInitializer.get_instance().train_params_config.get_initializer(True)) + else: + sess = tf.compat.v1.Session(config=sess_config(dump_data=False)) + sess.run(tf.compat.v1.global_variables_initializer()) + if not MODIFY_GRAPH_FLAG: + sess.run(train_iterator.initializer) + else: + sess.run(ConfigInitializer.get_instance().train_params_config.get_initializer(True)) + + epoch = 0 + cost_sum = 0 + qps_sum = 0 + best_auc_income = 0 + best_auc_mat = 0 + iteration_per_loop = 10 + + train_ops = util.set_iteration_per_loop(sess, train_ops, 10) + + i = 0 + while True: + i += 1 + logger.info(f"################ training at step {i * iteration_per_loop} ################") + start_time = time.time() + + try: + grad, loss, lr, global_step = sess.run([train_ops, train_model.get("loss"), + cfg.learning_rate, cfg.global_step]) + except tf.errors.OutOfRangeError: + logger.info(f"Encounter the end of Sequence for training.") + break + + end_time = time.time() + cost_time = end_time - start_time + qps = (1 / cost_time) * rank_size * cfg.batch_size * iteration_per_loop + cost_sum += cost_time + logger.info(f"step: {i * iteration_per_loop}; training loss: {loss}") + logger.info(f"step: {i * iteration_per_loop}; grad: {grad}") + logger.info(f"step: {i * iteration_per_loop}; lr: {lr}") + logger.info(f"global step: {global_step}") + logger.info(f"step: {i * iteration_per_loop}; current sess cost time: {cost_time:.10f}; current QPS: {qps}") + logger.info(f"training at step:{i * iteration_per_loop}, table[{sparse_hashtable.table_name}], " + f"table size:{sparse_hashtable.size()}, table capacity:{sparse_hashtable.capacity()}") + + if i % (train_steps // iteration_per_loop) == 0: + if interval is not None: + test_auc_income, test_auc_mat, test_mean_log_loss = evaluate_fix(i * iteration_per_loop) + else: + test_auc_income, test_auc_mat, test_mean_log_loss = evaluate() + print("Test auc income: {};Test auc mat: {} ;log_loss: {} ".format(test_auc_income, + test_auc_mat, test_mean_log_loss)) + best_auc_income = max(best_auc_income, test_auc_income) + best_auc_mat = max(best_auc_mat, test_auc_mat) + 
logger.info(f"training step: {i * iteration_per_loop}, best auc income: " + f"{best_auc_income} , best auc mat: {best_auc_mat}") + + + sess.close() + + terminate_config_initializer() + logger.info("Demo done!") diff --git a/examples/mmoe/model.py b/examples/mmoe/model.py new file mode 100644 index 0000000000000000000000000000000000000000..8cbb7ba8f2f823b6333a1c35cc0e150773f32298 --- /dev/null +++ b/examples/mmoe/model.py @@ -0,0 +1,141 @@ +# coding=utf-8 +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import time +from easydict import EasyDict as edict + +import tensorflow as tf + + +model_cfg = edict() +model_cfg.loss_mode = "batch" +LOSS_OP_NAME = "loss" +LABEL_OP_NAME = "label" +VAR_LIST = "variable" +PRED_OP_NAME = "pred" + + +class MyModel: + def __init__(self, expert_num=8, expert_size=16, tower_size=8, gate_num=2): + + self.expert_num = expert_num + self.expert_size = expert_size + self.tower_size = tower_size + self.gate_num = gate_num + + + def expert_layer(self, _input): + param_expert = [] + for i in range(0, self.expert_num): + expert_linear = tf.layers.dense(_input, units=self.expert_size, activation=None, name=f'expert_layer_{i}', + kernel_initializer=tf.constant_initializer(value=0.1), + bias_initializer=tf.constant_initializer(value=0.1)) + + param_expert.append(expert_linear) + return param_expert + + + def gate_layer(self, _input): + param_gate = [] + for i in range(0, self.gate_num): + gate_linear = tf.layers.dense(_input, units=self.expert_num, activation=None, name=f'gate_layer_{i}', + kernel_initializer=tf.constant_initializer(value=0.1), + bias_initializer=tf.constant_initializer(value=0.1)) + + param_gate.append(gate_linear) + return param_gate + + + def tower_layer(self, _input, layer_name): + tower_linear = tf.layers.dense(_input, units=self.tower_size, activation='relu', + name=f'tower_layer_{layer_name}', + kernel_initializer=tf.constant_initializer(value=0.1), + bias_initializer=tf.constant_initializer(value=0.1)) + + tower_linear_out = tf.layers.dense(tower_linear, units=2, activation=None, + name=f'tower_payer_out_{layer_name}', + kernel_initializer=tf.constant_initializer(value=0.1), + bias_initializer=tf.constant_initializer(value=0.1)) + + return tower_linear_out + + + + + def build_model(self, + embedding=None, + dense_feature=None, + label=None, + is_training=True, + seed=None): + + with tf.variable_scope("mmoe", reuse=tf.AUTO_REUSE): + + dense_expert = self.expert_layer(dense_feature) + dense_gate = self.gate_layer(dense_feature) + + all_expert = [] + _slice_num = 0 + for i in range(0, self.expert_num): + slice_num_end = _slice_num + self.expert_size + cur_expert = tf.add(dense_expert[i], embedding[:, _slice_num:slice_num_end]) + cur_expert = tf.nn.relu(cur_expert) + all_expert.append(cur_expert) + _slice_num = slice_num_end + + expert_concat = tf.concat(all_expert, axis=1) + expert_concat = 
tf.reshape(expert_concat, [-1, self.expert_num, self.expert_size]) + + output_layers = [] + out_pred = [] + for i in range(0, self.gate_num): + slice_gate_end = _slice_num + self.expert_num + cur_gate = tf.add(dense_gate[i], embedding[:, _slice_num:slice_gate_end]) + cur_gate = tf.nn.softmax(cur_gate) + + cur_gate = tf.reshape(cur_gate, [-1, self.expert_num, 1]) + + cur_gate_expert = tf.multiply(x=expert_concat, y=cur_gate) + cur_gate_expert = tf.reduce_sum(cur_gate_expert, axis=1) + + out = self.tower_layer(cur_gate_expert, i) + out = tf.nn.softmax(out) + out = tf.clip_by_value(out, clip_value_min=1e-15, clip_value_max=1.0 - 1e-15) + output_layers.append(out) + out_pred.append(tf.nn.softmax(out[:, 1])) + _slice_num = slice_gate_end + trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='mmoe') + + label_income = label[:, 0:1] + label_mat = label[:, 1:] + + pred_income_1 = tf.slice(output_layers[0], [0, 1], [-1, 1]) + pred_marital_1 = tf.slice(output_layers[1], [0, 1], [-1, 1]) + + cost_income = tf.losses.log_loss(labels=tf.cast(label_income, tf.float32), predictions=pred_income_1, + epsilon=1e-4) + cost_marital = tf.losses.log_loss(labels=tf.cast(label_mat, tf.float32), predictions=pred_marital_1, + epsilon=1e-4) + + avg_cost_income = tf.reduce_mean(cost_income) + avg_cost_marital = tf.reduce_mean(cost_marital) + + loss = 0.5 * (avg_cost_income + avg_cost_marital) + + return {LOSS_OP_NAME: loss, + PRED_OP_NAME: out_pred, + LABEL_OP_NAME: label, + VAR_LIST: trainable_variables} diff --git a/examples/mmoe/op_impl_mode.ini b/examples/mmoe/op_impl_mode.ini new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/mmoe/optimizer.py b/examples/mmoe/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5469c705c7476f007aa56ffc6f8af85ee328fc05 --- /dev/null +++ b/examples/mmoe/optimizer.py @@ -0,0 +1,33 @@ +# coding=utf-8 +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import tensorflow as tf + +from mx_rec.util.initialize import ConfigInitializer +from mx_rec.optimizers.lazy_adam import create_hash_optimizer +from mx_rec.optimizers.lazy_adam_by_addr import create_hash_optimizer_by_address + + + +def get_dense_and_sparse_optimizer(cfg): + dense_optimizer = tf.train.AdamOptimizer(learning_rate=cfg.learning_rate[0]) + use_dynamic_expansion = ConfigInitializer.get_instance().use_dynamic_expansion + if use_dynamic_expansion: + sparse_optimizer = create_hash_optimizer_by_address(learning_rate=cfg.learning_rate[1]) + else: + sparse_optimizer = create_hash_optimizer(learning_rate=cfg.learning_rate[1]) + + return dense_optimizer, sparse_optimizer diff --git a/examples/mmoe/run.sh b/examples/mmoe/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..6c1424435bac57a727cfc1980e25e08c6c5ba74c --- /dev/null +++ b/examples/mmoe/run.sh @@ -0,0 +1,99 @@ +#!/bin/bash +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +cur_path=$(dirname "$(readlink -f "$0")") + +so_path=$1 +mx_rec_package_path=$2 +hccl_cfg_json=$3 +dlrm_criteo_data_path=$4 +ip=$5 # no ranktable时传入该参数 + +interface="lo" +num_server=1 +local_rank_size=8 +num_process=$((num_server * local_rank_size)) +export TRAIN_RANK_SIZE=$num_process + +################# 参数配置 ###################### +export USE_DYNAMIC=0 # 0:静态shape;1:动态shape +export CACHE_MODE="HBM" # HBM;DDR;SSD +export USE_FAAE=0 # 0:关闭准入淘汰;1:开启准入淘汰 +export USE_DYNAMIC_EXPANSION=0 # 0:关闭动态扩容;1: 开启动态扩容 +export USE_MULTI_LOOKUP=0 # 0:一表一查;1:一表多查 +export USE_MODIFY_GRAPH=0 # 0:feature spec模式;1:自动改图模式 +################################################ +echo "CACHE_MODE:${CACHE_MODE}" + +export HCCL_CONNECT_TIMEOUT=1200 +export DLRM_CRITEO_DATA_PATH=${dlrm_criteo_data_path} +export PYTHONPATH=${mx_rec_package_path}:${so_path}:$PYTHONPATH +export LD_PRELOAD=/usr/lib64/libgomp.so.1 +export LD_LIBRARY_PATH=${so_path}:/usr/local/lib:$LD_LIBRARY_PATH +export ASCEND_DEVICE_ID=0 +export RANK_ID_START=0 +export JOB_ID=10086 +export CUSTOMIZED_OPS_LIB_PATH=${so_path}/libcust_ops.so # Todo: please config +export MXREC_LOG_LEVEL="INFO" +export TF_CPP_MIN_LOG_LEVEL=3 +export ASCEND_GLOBAL_LOG_LEVEL=3 +#export USE_FAAE=1 +export ENABLE_FORCE_V2_CONTROL=1 + +export PROFILING_OPTIONS='{"output":"/home/yz/profiling", + "training_trace":"on", + "task_trace":"on", + "aicpu":"on", + "fp_point":"", + "bp_point":"", + "aic_metrics":"PipeUtilization"}' + +RANK_ID_START=0 + +export MXREC_MODE="ASC" +echo "MXREC_MODE is $MXREC_MODE" +export py=main_mxrec.py +echo "py is $py" + +# 区分ranktable和no ranktable +if [ -n "$ip" ]; then + # no ranktable分支 + echo "Current is no ranktable solution." + echo "Input node ip: $ip, please make sure this ip is available." 
+    export CM_CHIEF_IP=$ip # 主节点ip
+    export CM_CHIEF_PORT=60001 # 主节点监听端口
+    export CM_CHIEF_DEVICE=0 # 主节点device id
+    export CM_WORKER_IP=$ip # 当前节点ip
+    export CM_WORKER_SIZE=$num_process # 参与集群训练的device数量
+    echo "CM_CHIEF_IP=$CM_CHIEF_IP"
+    echo "CM_CHIEF_PORT=$CM_CHIEF_PORT"
+    echo "CM_CHIEF_DEVICE=$CM_CHIEF_DEVICE"
+    echo "CM_WORKER_IP=$CM_WORKER_IP"
+    echo "CM_WORKER_SIZE=$CM_WORKER_SIZE"
+else
+    # ranktable分支
+    echo "Current is ranktable solution, hccl json file:${hccl_cfg_json}"
+    export RANK_SIZE=$num_process
+    echo "RANK_SIZE=${RANK_SIZE}, please make sure hccl configuration json file match this parameter"
+    export RANK_TABLE_FILE=${hccl_cfg_json}
+fi
+
+echo "use horovod to start tasks"
+# GLOG_stderrthreshold -2:TRACE -1:DEBUG 0:INFO 1:WARN 2:ERROR, 默认为INFO
+mpi_args='-x BIND_INFO="0:12 12:48 60:48" -x GLOG_stderrthreshold=2 -x GLOG_logtostderr=true -bind-to none -x NCCL_SOCKET_IFNAME=docker0 -mca btl_tcp_if_exclude docker0'
+
+horovodrun --network-interface ${interface} -np ${num_process} --mpi-args "${mpi_args}" --mpi -H localhost:${local_rank_size} \
+python3.7 ${py} 2>&1 | tee temp_${CACHE_MODE}_${num_process}p.log
diff --git a/examples/ps_adapt_to_mxrec/ps_adapt_to_mxrec.md b/examples/ps_adapt_to_mxrec/ps_adapt_to_mxrec.md
new file mode 100644
index 0000000000000000000000000000000000000000..431133b73a36d201d43c3377627e92db1a770f04
--- /dev/null
+++ b/examples/ps_adapt_to_mxrec/ps_adapt_to_mxrec.md
@@ -0,0 +1,750 @@
+# Version Information
+
+1. ps-lite
+
+   [GitHub - dmlc/ps-lite: A lightweight parameter server interface](https://github.com/dmlc/ps-lite)
+
+   commit 11b42c08a357d4ea5924403daa357587f4d8b5e2 (this commit or any later one works)
+
+2. mxRec
+
+   [mxrec: Huawei Ascend MindX Recommendation SDK - Gitee.com](https://gitee.com/ascend/mxrec/tree/develop/)
+
+   commit ae36047f1dda8c03fa849184205bdc8bcfb4a137
+
+**Note: ps-lite does not support multi-table storage, so this document uses the single-table training scenario as its example.**
+
+# Adaptation Workflow
+
+## ps-lite
+
+### Download the ps-lite code
+
+```shell
+# from the mxrec root directory
+cd mxrec/src
+mkdir 3rdparty
+cd 3rdparty
+git clone https://github.com/dmlc/ps-lite.git
+```
+
+### Modify ps-lite/make/deps.mk
+
+* Keep the downloaded source archives (wget -nc) instead of deleting them, to avoid re-downloading on repeated builds
+* Align the dependency versions with ps-lite/CMakeLists.txt. protobuf 3.8.0 is the version matching tensorflow 1.15; adjust it to your own tf version.
+
+```makefile
+# protobuf
+PROTOBUF = ${DEPS_PATH}/include/google/protobuf/message.h
+${PROTOBUF}:
+    $(eval FILE=protobuf-cpp-3.8.0.tar.gz)
+    $(eval DIR=protobuf-3.8.0)
+    rm -rf $(DIR)
+    $(WGET) -nc $(URL2)/$(FILE) && tar --no-same-owner -zxf $(FILE)
+    cd $(DIR) && export CFLAGS=-fPIC && export CXXFLAGS=-fPIC && ./configure -prefix=$(DEPS_PATH) && $(MAKE) && $(MAKE) install
+    rm -rf $(DIR)
+
+# zmq
+ZMQ = ${DEPS_PATH}/include/zmq.h
+
+${ZMQ}:
+    $(eval FILE=zeromq-4.3.2.tar.gz)
+    $(eval DIR=zeromq-4.3.2)
+    rm -rf $(DIR)
+    $(WGET) -nc $(URL1)/$(FILE) && tar --no-same-owner -zxf $(FILE)
+    cd $(DIR) && export CFLAGS=-fPIC && export CXXFLAGS=-fPIC && ./configure -prefix=$(DEPS_PATH) --with-libsodium=no --with-libgssapi_krb5=no && $(MAKE) && $(MAKE) install
+    rm -rf $(DIR)
+
+# lz4
+LZ4 = ${DEPS_PATH}/include/lz4.h
+${LZ4}:
+    $(eval FILE=lz4-r129.tar.gz)
+    $(eval DIR=lz4-r129)
+    rm -rf $(DIR)
+    wget -nc $(URL1)/$(FILE) && tar --no-same-owner -zxf $(FILE)
+    cd $(DIR) && $(MAKE) && PREFIX=$(DEPS_PATH) $(MAKE) install
+    rm -rf $(DIR)
+
+# cityhash
+CITYHASH = ${DEPS_PATH}/include/city.h
+${CITYHASH}:
+    $(eval FILE=cityhash-1.1.1.tar.gz)
+    $(eval DIR=cityhash-1.1.1)
+    rm -rf $(DIR)
+    wget -nc $(URL1)/$(FILE)&& tar --no-same-owner -zxf $(FILE)
+    cd $(DIR) && ./configure -prefix=$(DEPS_PATH) --enable-sse4.2 && $(MAKE) CXXFLAGS="-g -O3 -msse4.2" && $(MAKE) install
+    rm -rf $(DIR)
+```
+
+### Install dependencies
+
+* protobuf: the version must match the one tensorflow was built with; search the tensorflow directory for `GOOGLE_PROTOBUF_VERSION` to check the protobuf version
+
+* zeromq: see the project's GitHub page; use the version shown in ps-lite/make/deps.mk
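+A quick way to read the version off is to grep the protobuf headers bundled with the tensorflow package for that macro; a minimal sketch, assuming a pip-installed tensorflow for python3.7 (header locations differ between tf1's tensorflow_core and tf2's tensorflow directories):
+
+```shell
+# assumption: pip-installed tensorflow; adjust the interpreter to your environment
+site=$(python3.7 -c "import sysconfig; print(sysconfig.get_paths()['purelib'])")
+grep -rn "define GOOGLE_PROTOBUF_VERSION" "$site"/tensorflow* 2>/dev/null
+```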
+### Prepare the KVServerMxRecHandle source
+
+Add the following code to ps-lite/include/ps/kv_app.h:
+
+```c++
+/**
+ * \brief for mxrec embedding storage
+ */
+template <typename Val>
+struct KVServerMxRecHandle {
+    void operator()(
+        const KVMeta& req_meta, const KVPairs<Val>& req_data, KVServer<Val>* server) {
+        LL << "KVServerMxRecHandle, customerId:" << req_meta.customer_id << ", push:" << req_meta.push << ", pull:" << req_meta.pull;
+        auto es = std::getenv("EMB_SIZE");
+        if (es == nullptr) {
+            throw std::runtime_error("EMB_SIZE environment variable not found, please export");
+        }
+        int embeddingSize = std::stoi(es);
+        size_t keyCnt = req_data.keys.size();
+        KVPairs<Val> res;
+
+        if (req_meta.pull) {
+            LL << "pull, customerId:" << req_meta.customer_id << ", keys.size:" << keyCnt << ", embeddingSize:" << embeddingSize;
+            res.keys = req_data.keys;
+            res.vals.resize(keyCnt * embeddingSize); // flatten all data
+            for (size_t i = 0; i < keyCnt; ++i) {
+                Key key = req_data.keys[i];
+                std::vector<Val> emb = store[key];
+                if (emb.size() == 0) {
+                    emb = std::vector<Val>(embeddingSize, 0);
+                } else if (emb.size() != static_cast<size_t>(embeddingSize)) {
+                    throw std::runtime_error("embedding size in server not equal to request");
+                }
+                for (int j = 0; j < embeddingSize; j++) {
+                    res.vals[i * embeddingSize + j] = emb[j];
+                }
+            }
+        } else if (req_meta.push) {
+            LL << "push, customerId:" << req_meta.customer_id << ", keys.size:" << keyCnt << ", vals.size:" << req_data.vals.size() << ", embeddingSize:" << embeddingSize;
+            for (size_t i = 0; i < keyCnt; i++) {
+                Key key = req_data.keys[i];
+                std::vector<Val> tmp(embeddingSize);
+                for (int j = 0; j < embeddingSize; j++) {
+                    // copy from the request payload, not from the (still empty) response
+                    tmp[j] = req_data.vals[i * embeddingSize + j];
+                }
+                store[key] = tmp;
+            }
+        } else {
+            LL << "error: request neither push nor pull";
+            throw std::runtime_error("request neither push nor pull");
+        }
+
+        server->Response(req_meta, res);
+    }
+    std::unordered_map<Key, std::vector<Val>> store;
+};
+```
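+Note that the handler reads the embedding dimension from the `EMB_SIZE` environment variable instead of `req_data.lens` (the lens limitation is noted again in ps_store.cpp below), so every process hosting a server must export it before startup, otherwise the handler throws on the first request. A minimal sketch; the value 2 is only what test_worker.cc below happens to use, and start_service.sh below exports it accordingly:
+
+```shell
+# must equal the embedding size the workers push/pull (embSize = 2 in test_worker.cc)
+export EMB_SIZE=2
+```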
+### Prepare the scheduler, server, and worker sources
+
+* ps-lite/tests/test_scheduler.cc
+
+  ```c++
+  #include <string>
+  #include "ps/ps.h"
+
+  using namespace ps;
+
+  void RunScheduler(int appId) {
+      // start system
+      LL << "start scheduler, appId:" << appId;
+      Start(appId);
+      Finalize(appId, true);
+      LL << "quit scheduler, appId:" << appId;
+  }
+
+  int main(int argc, char *argv[]) {
+      int appId = std::stoi(argv[1]);
+      RunScheduler(appId);
+      return 0;
+  }
+  ```
+
+* ps-lite/tests/test_server.cc
+
+  ```c++
+  #include <string>
+  #include "ps/ps.h"
+
+  using namespace ps;
+
+  void StartServer(int serverId) {
+      if (!IsServer()) {
+          return;
+      }
+      auto server = new KVServer<float>(serverId);
+      server->set_request_handle(KVServerMxRecHandle<float>());
+      RegisterExitCallback([server](){ delete server; });
+  }
+
+  void RunServer(int appId) {
+      LL << "start server, appId:" << appId;
+      Start(appId);
+      StartServer(appId);
+      // stop system
+      Finalize(appId, true);
+      LL << "quit server, appId:" << appId;
+  }
+
+  int main(int argc, char *argv[]) {
+      int appId = std::stoi(argv[1]);
+      RunServer(appId);
+      return 0;
+  }
+  ```
+
+* ps-lite/tests/test_worker.cc
+
+  ```c++
+  #include <cmath>
+  #include <cstdlib>
+  #include <limits>
+  #include <thread>
+  #include "ps/ps.h"
+
+  using namespace ps;
+  using std::vector;
+
+
+  void RunWorker(int appId, int customerId) {
+      LL << "start worker, appId:" << appId << ", customerId:" << customerId;
+      Start(appId);
+      if (!IsWorker()) {
+          return;
+      }
+      KVWorker<float> kv(appId, customerId);
+
+      // init
+      int num = 10000;
+      int embSize = 2;
+      vector<int> lens(num, embSize);
+      vector<Key> keys(num);
+      vector<float> vals(num * embSize);
+      int rank = MyRank();
+      srand(rank + 7);
+      for (int i = 0; i < num; ++i) {
+          keys[i] = kMaxKey / num * i + customerId;
+          for (int j = 0; j < embSize; ++j) {
+              vals[i * embSize + j] = rand() % 1000;
+          }
+      }
+
+      // push
+      LL << "start push";
+      kv.Wait(kv.Push(keys, vals));
+
+      // pull
+      LL << "start pull";
+      std::vector<float> rets;
+      kv.Wait(kv.Pull(keys, &rets));
+
+      LL << "start validation";
+      for (int i = 0; i < num; ++i) {
+          for (int j = 0; j < embSize; ++j) {
+              if (std::fabs(vals[i * embSize + j] - rets[i * embSize + j]) > std::numeric_limits<float>::epsilon()) {
+                  LL << "error: embedding from server not equal to original data";
+                  Finalize(appId, true);
+                  return;
+              }
+          }
+      }
+
+      // stop system
+      Finalize(appId, true);
+      LL << "stop worker, appId:" << appId << ", customerId:" << customerId;
+  }
+
+  int main(int argc, char *argv[]) {
+      int customerId = std::stoi(argv[1]);
+      std::thread t0(RunWorker, 0, customerId);
+      t0.join();
+      return 0;
+  }
+  ```
+
+### Modify ps-lite/tests/CMakeLists.txt
+
+Replace its contents with:
+
+```cmake
+add_executable(test_scheduler test_scheduler.cc)
+target_link_libraries(test_scheduler pslite)
+
+add_executable(test_server test_server.cc)
+target_link_libraries(test_server pslite)
+
+add_executable(test_worker test_worker.cc)
+target_link_libraries(test_worker pslite)
+```
+
+### Modify ps-lite/CMakeLists.txt
+
+Add the following line:
+
+```cmake
+target_link_libraries(pslite PUBLIC pthread)
+```
+
+### Build the scheduler, server, and worker
+
+Run the following in the ps-lite directory:
+
+```shell
+mkdir build
+cd build
+cmake ..
+make -j4
+```
+
+### Prepare the scheduler, server, and worker launch scripts
+
+* ps-lite/start_service.sh
+
+  ```shell
+  #!/bin/bash
+  # set -x
+  if [ $# -lt 2 ]; then
+      echo "usage: $0 bin_scheduler bin_server"
+      exit -1;
+  fi
+
+  export DMLC_NUM_SERVER=1
+  export DMLC_NUM_WORKER=1
+  # embedding dim consumed by KVServerMxRecHandle; must match the workers
+  export EMB_SIZE=2
+  bin_scheduler=$1
+  bin_server=$2
+
+  # start the scheduler
+  export DMLC_PS_ROOT_URI='127.0.0.1'
+  export DMLC_ROLE='scheduler'
+  export DMLC_PS_ROOT_PORT=8000
+  ${bin_scheduler} 0 &
+
+  # start servers
+  export DMLC_ROLE='server'
+  ${bin_server} 0 &
+
+  wait
+  ```
+
+* ps-lite/start_worker.sh
+
+  ```shell
+  #!/bin/bash
+  # set -x
+  if [ $# -lt 1 ]; then
+      echo "usage: $0 bin_worker"
+      exit -1;
+  fi
+
+  export DMLC_NUM_SERVER=1
+  export DMLC_NUM_WORKER=1
+  bin_worker=$1
+
+  # scheduler info
+  export DMLC_PS_ROOT_URI='127.0.0.1'
+  export DMLC_PS_ROOT_PORT=8000
+  export DMLC_ROLE='worker'
+  ${bin_worker} 0 &
+
+  wait
+  ```
+
+### Verify that the basic functionality works
+
+Copy the compiled test binaries into the ps-lite directory and run:
+
+```shell
+# run these in two separate shells
+./start_service.sh ./test_scheduler ./test_server
+./start_worker.sh ./test_worker
+```
+
+If no errors are reported, everything is working.
+
+## mxrec
+
+### Adjust ps-lite
+
+1. Delete ps-lite/build
+2. In ps-lite/CMakeLists.txt, comment out `add_subdirectory(tests)`
+
+Search for the following code snippets and add or replace the source code accordingly.
+
+### src/build.sh
+
+```shell
+cmake -DCMAKE_BUILD_TYPE=Release \
+    -DTF_PATH="$1" \
+    -DOMPI_PATH="$(whereis openmpi)" \
+    -DPYTHON_PATH="$python_path" \
+    -DEASY_PROFILER_PATH=/ \
+    -DASCEND_PATH="$ascend_path" \
+    -DABSEIL_PATH="$1" \
+    -DSECUREC_PATH="$2"/../opensource/securec \
+    -DCMAKE_INSTALL_PREFIX="$2"/output \
+    -DBUILD_CUST="$3" .. \
+
+## mxrec
+
+### Adjust ps-lite
+
+1. Delete ps-lite/build
+2. In ps-lite/CMakeLists.txt, comment out `add_subdirectory(tests)`
+
+Search for the following code snippets and add or replace the source accordingly.
+
+### src/build.sh
+
+```shell
+cmake -DCMAKE_BUILD_TYPE=Release \
+    -DTF_PATH="$1" \
+    -DOMPI_PATH="$(whereis openmpi)" \
+    -DPYTHON_PATH="$python_path" \
+    -DEASY_PROFILER_PATH=/ \
+    -DASCEND_PATH="$ascend_path" \
+    -DABSEIL_PATH="$1" \
+    -DSECUREC_PATH="$2"/../opensource/securec \
+    -DCMAKE_INSTALL_PREFIX="$2"/output \
+    -DBUILD_CUST="$3" \
+    -DDEPS_PATH="$2"/src/3rdparty/ps-lite .. # new
+```
+
+### src/CMakeLists.txt
+
+```cmake
+add_subdirectory(dataset_tf)
+add_subdirectory(core/3rdparty/ps-lite) # new
+```
+
+### src/core/CMakeLists.txt
+
+```cmake
+file(GLOB_RECURSE MXREC_SRC ./*.cpp ./*.h)
+add_library(ASC SHARED ${MXREC_SRC})
+
+target_include_directories(ASC PUBLIC 3rdparty/ps-lite/include) # new
+```
+
+```cmake
+target_link_libraries(ASC PUBLIC ascendcl msprofiler ge_executor gert runtime ge_common register graph ascend_protobuf
+    profapi opt_feature error_manager exe_graph acl_tdt_channel acl_tdt_queue securec drvdsmi_host _ock_ctr_common
+    pslite # new
+)
+```
+
+### src/core/ps_store/ps_store.h **(new)**
+
+```c++
+#ifndef MXREC_PS_STORE_H
+#define MXREC_PS_STORE_H
+
+#include <map>
+#include <memory>
+
+#include "l3_storage/l3_storage.h"
+#include "ps/ps.h" // must be included after any mxrec header file, otherwise compilation fails
+
+using MxRec::L3Storage;
+using ps::KVWorker;
+using std::map;
+using std::shared_ptr;
+using std::string;
+using std::vector;
+
+namespace MxRec {
+// note: the template and container arguments below were restored from the .cpp example;
+// check them against the actual L3Storage declaration in mxRec
+class PSStore : public L3Storage {
+public:
+    explicit PSStore(int rankId);
+
+    bool IsTableExist(const string& tableName);
+
+    bool IsKeyExist(const string& tableName, emb_cache_key_t key);
+
+    void CreateTable(const string& tableName, vector<string> savePaths, uint64_t maxTableSize);
+
+    int64_t GetTableAvailableSpace(const string& tableName);
+
+    void InsertEmbeddingsByAddr(const string& tableName, vector<emb_cache_key_t>& keys, vector<void*>& embeddingsAddr,
+        uint64_t extEmbeddingSize);
+
+    void DeleteEmbeddings(const string& tableName, vector<emb_cache_key_t>& keys);
+
+    vector<vector<float>> FetchEmbeddings(const string& tableName, vector<emb_cache_key_t>& keys);
+
+    void Save(int step);
+
+    void Load(const string& tableName, vector<string> savePaths, uint64_t maxTableSize, int step);
+
+    void Start();
+
+    void Stop();
+
+    int64_t GetTableUsage(const string& tableName);
+
+    vector<vector<emb_cache_key_t>> ExportTableKey();
+
+private:
+    // ps-lite does not support multiple tables yet, thus this example code only uses one client
+    int appId = 0;
+    int customerId = 0;
+
+    // table --> client
+    map<string, shared_ptr<KVWorker<float>>> cliMap;
+};
+} // namespace MxRec
+#endif // MXREC_PS_STORE_H
+```
+
+### src/core/ps_store/ps_store.cpp **(new)**
+
+```c++
+#include "ps_store.h"
+
+#include <algorithm>
+#include <cstdlib>
+#include <stdexcept>
+
+using MxRec::PSStore;
+using MxRec::emb_cache_key_t;
+
+struct KeyWithIdx {
+    emb_cache_key_t key;
+    size_t index;
+};
+
+bool CompareKeyWithIdx(KeyWithIdx a, KeyWithIdx b)
+{
+    return a.key < b.key;
+}
+
+PSStore::PSStore(int rankId)
+{
+    auto startIdx = std::getenv("REC_WORKER_ID_START_IDX");
+    if (startIdx == nullptr) {
+        throw std::runtime_error("REC_WORKER_ID_START_IDX environment variable not found, please export");
+    }
+    this->customerId = rankId + std::stoi(startIdx);
+}
+
+bool PSStore::IsTableExist(const string& tableName)
+{
+    auto iter = cliMap.find(tableName);
+    if (iter == cliMap.end()) {
+        return false;
+    }
+    return true;
+}
+
+bool PSStore::IsKeyExist(const string& tableName, emb_cache_key_t key)
+{
+    auto iter = cliMap.find(tableName);
+    if (iter == cliMap.end()) {
+        LOG_DEBUG("table:{} not created yet", tableName);
+        throw std::runtime_error("table not created yet");
+    }
+
+    auto worker = cliMap[tableName];
+    vector<ps::Key> keys = {static_cast<ps::Key>(key)};
+    vector<float> rets;
+    worker->Wait(worker->Pull(keys, &rets));
+    // note: the example server handle returns zero-filled embeddings for unknown keys,
+    // so a non-empty reply only means the request round-tripped
+    if (rets.size() > 0) {
+        return true;
+    }
+    return false;
+}
+
+void PSStore::CreateTable(const string& tableName, vector<string> savePaths, uint64_t maxTableSize)
+{
+    static bool alreadyCreate = false;
+    if (alreadyCreate) {
+        throw std::runtime_error("ps-lite does not support multiple tables yet, thus this example code only supports one table");
+    }
+    LOG_DEBUG("start create table:{}, init ps-lite client, appId:{}, customerId:{}", tableName, appId, customerId);
+    ps::Start(appId);
+    auto worker = std::make_shared<KVWorker<float>>(appId, customerId);
+    cliMap[tableName] = worker;
+    LOG_DEBUG("finish create table:{}, worker appId:{}, customerId:{}", tableName, appId, customerId);
+    alreadyCreate = true;
+}
+
+int64_t PSStore::GetTableAvailableSpace(const string& tableName)
+{
+    // ps-lite doesn't have this API
+    // thus always available
+    return 1000000000000;
+}
+
+void PSStore::InsertEmbeddingsByAddr(const string& tableName, vector<emb_cache_key_t>& keys,
+    vector<void*>& embeddingsAddr, uint64_t extEmbeddingSize)
+{
+    if (keys.size() == 0) {
+        return;
+    }
+
+    auto iter = cliMap.find(tableName);
+    if (iter == cliMap.end()) {
+        LOG_DEBUG("table:{} not created yet", tableName);
+        throw std::runtime_error("table not created yet");
+    }
+    auto psCli = cliMap[tableName];
+
+    // note: ps-lite needs keys in ascending order
+    vector<KeyWithIdx> elements;
+    for (size_t i = 0; i < keys.size(); i++) {
+        KeyWithIdx e = {keys[i], i};
+        elements.push_back(e);
+    }
+    std::sort(elements.begin(), elements.end(), CompareKeyWithIdx);
+    vector<ps::Key> sortedKeys;
+    vector<void*> sortedEmbeddingsAddr;
+    for (size_t i = 0; i < elements.size(); i++) {
+        sortedKeys.push_back(static_cast<ps::Key>(elements[i].key));
+        sortedEmbeddingsAddr.push_back(embeddingsAddr[elements[i].index]);
+    }
+
+    // the embedding width is communicated via EMB_SIZE instead of ps-lite lens (see FetchEmbeddings)
+    vector<float> vals(embeddingsAddr.size() * extEmbeddingSize);
+    for (size_t i = 0; i < sortedEmbeddingsAddr.size(); i++) {
+        // memcpy_s sizes are in bytes, so scale the element count by sizeof(float)
+        auto rc = memcpy_s(vals.data() + i * extEmbeddingSize, extEmbeddingSize * sizeof(float),
+            sortedEmbeddingsAddr[i], extEmbeddingSize * sizeof(float));
+        if (rc != 0) {
+            throw std::runtime_error("copy embedding data failed");
+        }
+    }
+
+    LOG_DEBUG("start push to server, table:{}, keys.size:{}, vals.size:{}", tableName, keys.size(), vals.size());
+    int timeStamp = psCli->Push(sortedKeys, vals); // push the sorted keys, matching the sorted values
+    psCli->Wait(timeStamp);
+
+    LOG_DEBUG("end push embedding to server, table:{}", tableName);
+}
+
+void PSStore::DeleteEmbeddings(const string& tableName, vector<emb_cache_key_t>& keys)
+{
+    LOG_WARN("ps-lite doesn't have a delete function, just return");
+    return;
+}
+
+vector<vector<float>> PSStore::FetchEmbeddings(const string& tableName, vector<emb_cache_key_t>& keys)
+{
+    LOG_DEBUG("start pull embedding from server, table:{}, keys.size:{}", tableName, keys.size());
+    if (keys.size() == 0) {
+        return vector<vector<float>>();
+    }
+
+    auto iter = cliMap.find(tableName);
+    if (iter == cliMap.end()) {
+        LOG_DEBUG("table:{} not created yet", tableName);
+        throw std::runtime_error("table not created yet");
+    }
+    auto psCli = cliMap[tableName];
+
+    // note: ps-lite needs keys in ascending order
+    vector<KeyWithIdx> elements;
+    for (size_t i = 0; i < keys.size(); i++) {
+        KeyWithIdx e = {keys[i], i};
+        elements.push_back(e);
+    }
+    std::sort(elements.begin(), elements.end(), CompareKeyWithIdx);
+    vector<ps::Key> sortedKeys;
+    for (size_t i = 0; i < elements.size(); i++) {
+        sortedKeys.push_back(static_cast<ps::Key>(elements[i].key));
+    }
+
+    // input lens would get stuck at req_data.lens, so the EMB_SIZE environment variable is used as a workaround
+    std::vector<float> rets;
+    psCli->Wait(psCli->Pull(sortedKeys, &rets));
+
+    LOG_DEBUG("finish pull embedding, table:{}, embedding len:{}", tableName, rets.size());
+    if (rets.size() % keys.size() != 0) {
+        LOG_ERROR("can't split received embedding equally, keys.size:{}, embeddings.size:{}", keys.size(), rets.size());
+        throw std::runtime_error("embedding from server incomplete");
+    }
+
+    auto extEmbSize = rets.size() / keys.size(); // every key gets the same number of values back
+    vector<vector<float>> embs(keys.size());
+    for (size_t i = 0; i < elements.size(); i++) {
+        auto& emb = embs[elements[i].index]; // write through a reference, back in the original key order
+        emb.insert(emb.cbegin(), rets.cbegin() + i * extEmbSize, rets.cbegin() + (i + 1) * extEmbSize);
+    }
+
+    LOG_DEBUG("end pull embedding from server, table:{}", tableName);
+    return embs;
+}
+
+void PSStore::Save(int step)
+{
+    LOG_WARN("ps-lite doesn't have a save function, just return");
+}
+
+void PSStore::Load(const string& tableName, vector<string> savePaths, uint64_t maxTableSize, int step)
+{
+    LOG_WARN("ps-lite doesn't have a load function, just return");
+}
+
+void PSStore::Start()
+{
+    LOG_INFO("start ps store");
+}
+
+void PSStore::Stop()
+{
+    LOG_INFO("start stop ps store");
+    ps::Finalize(appId, true);
+    LOG_INFO("finish stop ps store");
+}
+
+int64_t PSStore::GetTableUsage(const string& tableName)
+{
+    LOG_WARN("ps-lite doesn't have a GetTableUsage function, just return 0");
+    return 0;
+}
+
+vector<vector<emb_cache_key_t>> PSStore::ExportTableKey()
+{
+    LOG_WARN("ps-lite doesn't have an export key function, just return empty result");
+    return vector<vector<emb_cache_key_t>>();
+}
+```
LOG_WARN("ps-lite don't have save function, just return"); +} + +void PSStore::Load(const string& tableName, vector savePaths, uint64_t maxTableSize, int step) +{ + LOG_WARN("ps-lite don't have save function, just return"); +} + +void PSStore::Start() +{ + LOG_INFO("start ps store"); +} + +void PSStore::Stop() +{ + LOG_INFO("start stop ps store"); + ps::Finalize(appId, true); + LOG_INFO("finish stop ps store"); +} + +int64_t PSStore::GetTableUsage(const string& tableName) +{ + LOG_WARN("ps-lite don't have GetTableUsage function, just return 0"); + return 0; +} + +vector>> PSStore::ExportTableKey() +{ + LOG_WARN("ps-lite don't have export key function, just return empty result"); + return vector>>(); +} +``` + +### src/core/hybrid_mgmt/hybrid_mgmt.cpp + +```c++ +#include "ps_store/ps_store.h" // new +``` + +```c++ +if (isL3StorageEnabled) { + cacheManager = Singleton::GetInstance(); + // 用户可实现L3Storage接口替换SSDEngine以对接外部存储服务 + auto psStore = std::make_shared(mgmtRankInfo.rankId); // replace + cacheManager->Init(embCache, mgmtEmbInfo, psStore); // replace + EmbeddingMgmt::Instance()->SetCacheManagerForEmbTable(cacheManager); +} +``` + +### 模型代码 + +以dcnV2为例,在run.sh中新增以下环境变量。 + +```shell +# ps-lite info +export DMLC_NUM_SERVER=1 +export DMLC_NUM_WORKER=8 # ausume we run 8 train process + +# scheduler info +export DMLC_PS_ROOT_URI='127.0.0.1' # user can set to remote server +export DMLC_PS_ROOT_PORT=8000 + +# set role as workers +export DMLC_ROLE='worker' + +# mark worker id for train process between multiple train server +# e.g. server A, worker id range [REC_WORKER_ID_START_IDX, +1, ..., +n]; server B, worker id range [REC_WORKER_ID_START_IDX +(n+1), +(n+2), ...] +export REC_WORKER_ID_START_IDX=0 +``` + +在ps-lite目录拉起存储服务 + +```shell +./start_service.sh ./test_schedular ./test_server +``` + +在模型目录拉起训练 + +```shell +# 修改缓存模式为SSD(按上述mxrec源码修改步骤,SSDEngine已被替换为ps-lite,为了不影响对外接口,未修改对外暴露的ssd参数,用户可自行修改) +export CACHE_MODE="SSD" + +./run.sh $LIBSAC_PATH $PYTHON_PATH $HCCL_JSON_PATH $DATA_PATH +``` + + + + + diff --git a/examples/rec_infer/README.md b/examples/rec_infer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..573ecafca07e01b16845302dcbd314e075640e33 --- /dev/null +++ b/examples/rec_infer/README.md @@ -0,0 +1,134 @@ +# 推理环境部署 +一、安装依赖包:

+Install the development toolkit package Ascend-cann-toolkit_{version}_linux-{arch}.run

+Install the framework plugin package Ascend-cann-tfplugin_{version}_linux-{arch}.run

+Install the other dependency packages:

+|Dependency | Version requirement|
+|:---|:---:|
+|gcc, g++|8.4 or later|
+|zip, unzip, libtool, automake|no specific version required|
+|python|3.7.5|
+|TensorFlow|1.15.0|
+|tensorflow-serving-api|1.15.0|
+|future|no specific version required|
+|bazel|0.24.1|
+|cmake|3.14.0|
+|swig|3.0.12 or later on "aarch64"; 4.0.1 or later on "x86_64"|
+|java|jdk-11|
+
+Part 2: Build Serving
+1. Download the TF Serving source: https://github.com/tensorflow/serving/archive/1.15.0.zip
+2. Extract it and enter the source directory
+3. Add the TF Serving third-party dependencies
+
+a) Run the following commands to create and enter a "tf_adapter" folder under "serving-1.15.0/third_party".
+>cd third_party/
+mkdir tf_adapter
+cd tf_adapter
+b) Run the following commands to copy "libpython3.7m.so.1.0" into the "tf_adapter" folder and create a symlink.
+> cp /usr/local/python3.7.5/lib/libpython3.7m.so.1.0 .
+ln -s libpython3.7m.so.1.0 libpython3.7m.so
+
+c) Run the following commands to copy "_tf_adapter.so" into the "tf_adapter" folder and rename it to "lib_tf_adapter.so".
+>cp /home/HwHiAiUser/Ascend/tfplugin/latest/python/site-packages/npu_bridge/_tf_adapter.so .
+mv _tf_adapter.so lib_tf_adapter.so
+
+4. Build empty libtensorflow_framework.so and _pywrap_tensorflow_internal.so stub libraries.
+
+a) In the "tf_adapter" folder, run the following command.
+>vim CMakeLists.txt
+
+b) Write the following content and save.
+```text
+file(TOUCH ${CMAKE_CURRENT_BINARY_DIR}/stub.c)
+add_library(_pywrap_tensorflow_internal SHARED ${CMAKE_CURRENT_BINARY_DIR}/stub.c)
+add_library(tensorflow_framework SHARED ${CMAKE_CURRENT_BINARY_DIR}/stub.c)
+```
+
+c) Run :wq! to save the file and exit.
+d) Run the following commands to build the empty .so files.
+> mkdir temp
+cd temp
+cmake ..
+make
+mv lib_pywrap_tensorflow_internal.so ../_pywrap_tensorflow_internal.so
+mv libtensorflow_framework.so ../libtensorflow_framework.so
+cd ..
+ln -s libtensorflow_framework.so libtensorflow_framework.so.1
+
+e) Configure the environment variable.
+```text
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$(pwd)
+```
+
+5. Create a BUILD file in the "tf_adapter" folder and write the following content.
+```text
+licenses(["notice"])  # BSD/MIT.
+
+cc_import(
+    name = "tf_adapter",
+    shared_library = "lib_tf_adapter.so",
+    visibility = ["//visibility:public"]
+)
+
+cc_import(
+    name = "tf_python",
+    shared_library = "libpython3.7m.so",
+    visibility = ["//visibility:public"]
+)
+```
+
+6. Modify the BUILD file under "serving-1.15.0/tensorflow_serving/model_servers/", adding the bold lines below inside "cc_binary".
+
+>cc_binary(
+     name = "tensorflow_model_server",
+     stamp = 1,
+     visibility = [
+         ":testing",
+         "//tensorflow_serving:internal",
+     ],
+     deps = [
+         ":tensorflow_model_server_main_lib",
+         __"//third_party/tf_adapter:tf_adapter",__
+         __"//third_party/tf_adapter:tf_python",__
+         __"@org_tensorflow//tensorflow/compiler/jit:xla_cpu_jit",__
+     ],
+)
+
+7. Build TF Serving. In the TF Serving source directory "serving-1.15.0", run the following command.
+
+> bazel --output_user_root=/opt/tf_serving build -c opt --distdir=../depends --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" tensorflow_serving/model_servers:tensorflow_model_server
+If dependency downloads fail during the build, the packages can be downloaded manually; see the TF Serving build dependencies (https://www.hiascend.com/document/detail/zh/canncommercial/80RC1/developmentguide/moddevg/onlineinfer1/atlastfserv_26_0011.html)
+
+8. Create a symlink.
+> ln -s /opt/tf_serving/{tf_serving_ID}/execroot/tf_serving/bazel-out/xxx-opt/bin/tensorflow_serving/model_servers/tensorflow_model_server /usr/local/bin/tensorflow_model_server
+
++ {tf_serving_ID} is a random-looking string such as "063944eceea3e72745362a0b6eb12a3c"; fill it in according to your environment.
++ xxx-opt is a folder generated automatically by the tool; the actual name may differ.
+
+# Script tools
+server.sh/client.sh
+Script to start the server / script for the client to query the server
+
+1. Starting the tf-serving server
+Enter the tf_serving_inerence directory
+> Change the model path model_base_path in server.sh to the exported SavedModel path,
+> add the tf_adapter third-party dependency used when building tf_serving to the environment: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/xxx/xxx/serving-1.15.0/third_party/tf_adapter/,
+> source /usr/local/Ascend/ascend-toolkit/set_env.sh
+> sh server.sh
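+
+If serving was built with the gRPC patch from the optimize directory applied (see the performance optimization section below), the poller options it introduces can also be passed at startup. A sketch with the patch's default values, which should be tuned per model and machine:
+
+```shell
+# assumes serving was built with 0001-Performance-optimization-referrence.patch applied
+tensorflow_model_server --model_name=saved_model \
+    --model_base_path=$(pwd)/inference_model/saved_model/ \
+    --port=9999 \
+    --platform_config_file=test.cfg \
+    --set_SyncServerOption_flag=true --NUM_CQS=3 --MIN_POLLERS=6 --MAX_POLLERS=12
+```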
+
+If the log shows Running gRPC ModelServer at 0.0.0.0:xxxx, the server has started successfully
+2. Querying the server
+Run the script: sh client.sh
+On success, the end-to-end latency is printed
+
+# Using the graph partition tool
+1. Enter the graph_patition directory and change the model directory in gen_config.py
+2. Run python3 gen_config.py, then start the model with the generated test1.cfg file; usage:
+> python3 gen_config.py --output_path . --output_filename test1.cfg --model_path savedmodel_path
++ Parameters: output_path (output directory), output_filename (output file name), model_path (input model path)
++ Once the output file is generated, point the --platform_config_file option in the server startup script at it for the partition to take effect
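+
+A sketch of launching the server with the generated file, assuming the flags used in server.sh:
+
+```shell
+tensorflow_model_server --model_name=saved_model \
+    --model_base_path=$(pwd)/inference_model/saved_model/ \
+    --port=9999 \
+    --platform_config_file=test1.cfg
+```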
+
+# Performance optimization
+1. See the files in the optimize directory for details
\ No newline at end of file
diff --git a/examples/rec_infer/client.py b/examples/rec_infer/client.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c6f1cb18ac4b07aab552ddbee8e43059a60aaca
--- /dev/null
+++ b/examples/rec_infer/client.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import time
+
+import grpc
+import numpy as np
+
+import tensorflow as tf
+from input_config import config
+from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc
+
+
+class PredictModelGrpc:
+    def __init__(
+        self,
+        model_name,
+        inputs,
+        input_types,
+        output_name,
+        socket="xxx.xxx.xxx.xxx:8500",
+    ):
+        self.socket = socket
+        self.model_name = model_name
+        self.inputs = inputs
+        self.input_types = input_types
+        self.output_name = output_name
+        self.request, self.stub = self.__get_request()
+
+    def inference(self):
+        for name in self.inputs:
+            self.request.inputs[name].CopyFrom(
+                tf.make_tensor_proto(self.inputs[name], dtype=self.input_types[name])
+            )
+
+        for _ in range(100):
+            start = time.time()
+            result = self.stub.Predict.future(self.request, 1000.0)
+            result.result()
+            # print the end-to-end latency referred to in the README
+            print("end-to-end latency: {:.4f}s".format(time.time() - start))
+
+    def __get_request(self):
+        channel = grpc.insecure_channel(
+            self.socket,
+            options=[
+                ("grpc.max_send_message_length", 1024 * 1024 * 1024),
+                ("grpc.max_receive_message_length", 1024 * 1024 * 1024),
+            ],
+        )
+        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
+        request = predict_pb2.PredictRequest()
+        request.model_spec.name = self.model_name
+        request.model_spec.signature_name = "serving_default"
+
+        return request, stub
+
+
+FIELD_TYPE = "dtype"
+FIELD_SHAPE = "shape"
+
+
+def gen_inputs():
+    inputs = {}
+    input_types = {}
+    for name in config:
+        input_types[name] = config[name][FIELD_TYPE]
+        if config[name][FIELD_TYPE] == tf.int32:
+            inputs[name] = np.random.randint(0, 100, size=config[name][FIELD_SHAPE])
+        elif config[name][FIELD_TYPE] == tf.float32:
+            inputs[name] = np.random.randint(0, 2, size=config[name][FIELD_SHAPE]) * 1.0
+    return inputs, input_types
+
+
+if __name__ == "__main__":
+    input_datas, types = gen_inputs()
+    model = PredictModelGrpc(
+        model_name="saved_model",
+        inputs=input_datas,
+        input_types=types,
+        output_name="",
+        socket="127.0.0.1:9999",
+    )
+
+    model.inference()
diff --git a/examples/rec_infer/client.sh b/examples/rec_infer/client.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0d3169c2258b05581b926b753ab25dee76f5da45
--- /dev/null
+++ b/examples/rec_infer/client.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. 
+# Description: startup client + +source /usr/local/Ascend/ascend-toolkit/set_env.sh +unset http_proxy +unset https_proxy +python3 client.py \ No newline at end of file diff --git a/examples/rec_infer/input_config.py b/examples/rec_infer/input_config.py new file mode 100644 index 0000000000000000000000000000000000000000..8fff6ceb4267f3896ce15941851bb76089b34996 --- /dev/null +++ b/examples/rec_infer/input_config.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import tensorflow as tf + +BATCH_SIZE = 9600 +config = { + "feat_0": {"dtype": tf.float32, "shape": [BATCH_SIZE, 40], "name": "feat_0"}, + "feat_1": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_1"}, + "feat_2": {"dtype": tf.float32, "shape": [BATCH_SIZE, 40], "name": "feat_2"}, + "feat_3": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_3"}, + "feat_4": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_4"}, + "feat_5": {"dtype": tf.float32, "shape": [BATCH_SIZE, 32], "name": "feat_5"}, + "feat_6": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_6"}, + "feat_7": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_7"}, + "feat_8": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_8"}, + "feat_9": {"dtype": tf.int32, "shape": [BATCH_SIZE, 16], "name": "feat_9"}, + "feat_10": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_10"}, + "feat_11": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_11"}, + "feat_12": {"dtype": tf.float32, "shape": [BATCH_SIZE, 480], "name": "feat_12"}, + "feat_13": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_13"}, + "feat_14": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_14"}, + "feat_15": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_15"}, + "feat_16": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_16"}, + "feat_17": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_17"}, + "feat_18": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_18"}, + "feat_19": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_19"}, + "feat_20": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_20"}, + "feat_21": {"dtype": tf.float32, "shape": [BATCH_SIZE, 32], "name": "feat_21"}, + "feat_22": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_22"}, + "feat_23": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_23"}, + "feat_24": {"dtype": tf.int32, "shape": [BATCH_SIZE, 10], "name": "feat_24"}, + "feat_25": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_25"}, + "feat_26": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_26"}, + "feat_27": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_27"}, + "feat_28": {"dtype": tf.int32, "shape": [BATCH_SIZE, 36], "name": "feat_28"}, + 
"feat_29": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_29"}, + "feat_30": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_30"}, + "feat_31": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_31"}, + "feat_32": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_32"}, + "feat_33": {"dtype": tf.float32, "shape": [BATCH_SIZE, 256], "name": "feat_33"}, + "feat_34": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_34"}, + "feat_35": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_35"}, + "feat_36": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_36"}, + "feat_37": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_37"}, + "feat_38": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_38"}, + "feat_39": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_39"}, + "feat_40": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_40"}, + "feat_41": {"dtype": tf.float32, "shape": [BATCH_SIZE, 32], "name": "feat_41"}, + "feat_42": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_42"}, + "feat_43": {"dtype": tf.float32, "shape": [BATCH_SIZE, 40], "name": "feat_43"}, + "feat_44": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_44"}, + "feat_45": {"dtype": tf.int32, "shape": [BATCH_SIZE, 7], "name": "feat_45"}, + "feat_46": {"dtype": tf.int32, "shape": [BATCH_SIZE, 4], "name": "feat_46"}, + "feat_47": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_47"}, + "feat_48": {"dtype": tf.int32, "shape": [BATCH_SIZE, 4], "name": "feat_48"}, + "feat_49": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_49"}, + "feat_50": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_50"}, + "feat_51": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_51"}, + "feat_52": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_52"}, + "feat_53": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_53"}, + "feat_54": {"dtype": tf.int32, "shape": [BATCH_SIZE, 100], "name": "feat_54"}, + "feat_55": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_55"}, + "feat_56": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_56"}, + "feat_57": {"dtype": tf.float32, "shape": [BATCH_SIZE, 8], "name": "feat_57"}, + "feat_58": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_58"}, + "feat_59": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_59"}, + "feat_60": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_60"}, + "feat_61": {"dtype": tf.float32, "shape": [BATCH_SIZE, 8], "name": "feat_61"}, + "feat_62": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_62"}, + "feat_63": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_63"}, + "feat_64": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_64"}, + "feat_65": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_65"}, + "feat_66": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_66"}, + "feat_67": {"dtype": tf.float32, "shape": [BATCH_SIZE, 192], "name": "feat_67"}, + "feat_68": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_68"}, + "feat_69": {"dtype": tf.float32, "shape": [BATCH_SIZE, 8], "name": "feat_69"}, + "feat_70": {"dtype": tf.float32, "shape": [BATCH_SIZE, 6, 32], "name": "feat_70"}, + "feat_71": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_71"}, + "feat_72": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_72"}, + "feat_73": {"dtype": tf.int32, 
"shape": [BATCH_SIZE, 40], "name": "feat_73"}, + "feat_74": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_74"}, + "feat_75": {"dtype": tf.int32, "shape": [BATCH_SIZE, 10], "name": "feat_75"}, + "feat_76": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_76"}, + "feat_77": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_77"}, + "feat_78": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_78"}, + "feat_79": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_79"}, + "feat_80": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_80"}, + "feat_81": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_81"}, + "feat_82": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_82"}, + "feat_83": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_83"}, + "feat_84": {"dtype": tf.float32, "shape": [BATCH_SIZE, 32], "name": "feat_84"}, + "feat_85": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_85"}, + "feat_86": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_86"}, + "feat_87": {"dtype": tf.float32, "shape": [BATCH_SIZE, 40], "name": "feat_87"}, + "feat_88": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_88"}, + "feat_89": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_89"}, + "feat_90": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_90"}, + "feat_91": {"dtype": tf.float32, "shape": [BATCH_SIZE, 40], "name": "feat_91"}, + "feat_92": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_92"}, + "feat_93": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_93"}, + "feat_94": {"dtype": tf.int32, "shape": [BATCH_SIZE, 36], "name": "feat_94"}, + "feat_95": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_95"}, + "feat_96": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_96"}, + "feat_97": {"dtype": tf.float32, "shape": [BATCH_SIZE, 320], "name": "feat_97"}, + "feat_98": {"dtype": tf.float32, "shape": [BATCH_SIZE, 1], "name": "feat_98"}, + "feat_99": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_99"}, + "feat_100": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_100"}, + "feat_101": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_101"}, + "feat_102": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_102"}, + "feat_103": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_103"}, + "feat_104": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_104"}, + "feat_105": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_105"}, + "feat_106": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_106"}, + "feat_107": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_107"}, + "feat_108": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_108"}, + "feat_109": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_109"}, + "feat_110": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_110"}, + "feat_111": {"dtype": tf.int32, "shape": [BATCH_SIZE, 36], "name": "feat_111"}, + "feat_112": {"dtype": tf.int32, "shape": [BATCH_SIZE, 10], "name": "feat_112"}, + "feat_113": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_113"}, + "feat_114": {"dtype": tf.float32, "shape": [BATCH_SIZE, 8], "name": "feat_114"}, + "feat_115": {"dtype": tf.float32, "shape": [BATCH_SIZE, 60], "name": "feat_115"}, + "feat_116": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_116"}, + "feat_117": {"dtype": tf.int32, "shape": 
[BATCH_SIZE, 40], "name": "feat_117"}, + "feat_118": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_118"}, + "feat_119": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_119"}, + "feat_120": {"dtype": tf.int32, "shape": [BATCH_SIZE, 13], "name": "feat_120"}, + "feat_121": {"dtype": tf.int32, "shape": [BATCH_SIZE, 3], "name": "feat_121"}, + "feat_122": {"dtype": tf.int32, "shape": [BATCH_SIZE, 9], "name": "feat_122"}, + "feat_123": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_123"}, + "feat_124": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_124"}, + "feat_125": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_125"}, + "feat_126": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_126"}, + "feat_127": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_127"}, + "feat_128": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_128"}, + "feat_129": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_129"}, + "feat_130": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_130"}, + "feat_131": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_131"}, + "feat_132": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_132"}, + "feat_133": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_133"}, + "feat_134": {"dtype": tf.int32, "shape": [BATCH_SIZE, 10], "name": "feat_134"}, + "feat_135": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_135"}, + "feat_136": {"dtype": tf.int32, "shape": [BATCH_SIZE, 33], "name": "feat_136"}, + "feat_137": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_137"}, + "feat_138": {"dtype": tf.int32, "shape": [BATCH_SIZE, 36], "name": "feat_138"}, + "feat_139": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_139"}, + "feat_140": {"dtype": tf.float32, "shape": [BATCH_SIZE, 40], "name": "feat_140"}, + "feat_141": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_141"}, + "feat_142": {"dtype": tf.int32, "shape": [BATCH_SIZE, 26], "name": "feat_142"}, + "feat_143": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_143"}, + "feat_144": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_144"}, + "feat_145": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_145"}, + "feat_146": {"dtype": tf.float32, "shape": [BATCH_SIZE, 8], "name": "feat_146"}, + "feat_147": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_147"}, + "feat_148": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_148"}, + "feat_149": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_149"}, + "feat_150": {"dtype": tf.float32, "shape": [BATCH_SIZE, 8], "name": "feat_150"}, + "feat_151": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_151"}, + "feat_152": {"dtype": tf.float32, "shape": [BATCH_SIZE, 7], "name": "feat_152"}, + "feat_153": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_153"}, + "feat_154": {"dtype": tf.float32, "shape": [BATCH_SIZE, 8], "name": "feat_154"}, + "feat_155": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_155"}, + "feat_156": {"dtype": tf.float32, "shape": [BATCH_SIZE, 8], "name": "feat_156"}, + "feat_157": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_157"}, + "feat_158": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_158"}, + "feat_159": {"dtype": tf.int32, "shape": [BATCH_SIZE, 8], "name": "feat_159"}, + "feat_160": {"dtype": tf.float32, "shape": [BATCH_SIZE, 40], "name": "feat_160"}, 
+ "feat_161": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_161"}, + "feat_162": {"dtype": tf.int32, "shape": [BATCH_SIZE, 36], "name": "feat_162"}, + "feat_163": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_163"}, + "feat_164": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_164"}, + "feat_165": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_165"}, + "feat_166": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_166"}, + "feat_167": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_167"}, + "feat_168": {"dtype": tf.int32, "shape": [BATCH_SIZE, 6], "name": "feat_168"}, + "feat_169": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_169"}, + "feat_170": {"dtype": tf.int32, "shape": [BATCH_SIZE, 40], "name": "feat_170"}, + "feat_172": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_172"}, + "feat_173": {"dtype": tf.int32, "shape": [BATCH_SIZE, 1], "name": "feat_173"}, +} diff --git a/examples/rec_infer/optimize/0001-Performance-optimization-referrence.patch b/examples/rec_infer/optimize/0001-Performance-optimization-referrence.patch new file mode 100644 index 0000000000000000000000000000000000000000..a35760558cf75f7525d254f42858a6257881f22a --- /dev/null +++ b/examples/rec_infer/optimize/0001-Performance-optimization-referrence.patch @@ -0,0 +1,72 @@ +--- + tensorflow_serving/model_servers/BUILD | 1 + + tensorflow_serving/model_servers/main.cc | 8 +++++++- + tensorflow_serving/model_servers/server.cc | 5 +++++ + tensorflow_serving/model_servers/server.h | 6 +++++- + 4 files changed, 18 insertions(+), 2 deletions(-) + +diff --git a/tensorflow_serving/model_servers/BUILD b/tensorflow_serving/model_servers/BUILD +index f60f3d7..e74a514 100644 +--- a/tensorflow_serving/model_servers/BUILD ++++ b/tensorflow_serving/model_servers/BUILD +@@ -373,6 +373,7 @@ cc_binary( + deps = [ + ":tensorflow_model_server_main_lib", + ], ++ linkops = ["-L/usr/local/lib -lstringlib", "-L/usr/local/lib -ljemalloc"] + ) + + py_test( +diff --git a/tensorflow_serving/model_servers/main.cc b/tensorflow_serving/model_servers/main.cc +index 2b83500..3a055d0 100644 +--- a/tensorflow_serving/model_servers/main.cc ++++ b/tensorflow_serving/model_servers/main.cc +@@ -192,7 +192,13 @@ int main(int argc, char** argv) { + "EXPERIMENTAL; CAN BE REMOVED ANYTIME! 
Load and use " + "TensorFlow Lite model from `model.tflite` file in " + "SavedModel directory instead of the TensorFlow model " +- "from `saved_model.pb` file.")}; ++ "from `saved_model.pb` file."), ++ tensorflow::Flag("set_SyncServerOption_flag", &options.set_SyncServerOption_flag, ++ "if true,the server will config SyncServerOption"), ++ tensorflow::Flag("NUM_CQS", &options.NUM_CQS, "config NUM_CQS"), ++ tensorflow::Flag("MIN_POLLERS", &options.MIN_POLLERS, "config MIN_POLLERS"), ++ tensorflow::Flag("MAX_POLLERS", &options.MAX_POLLERS, "config MAX_POLLERS"), ++ }; + + const auto& usage = tensorflow::Flags::Usage(argv[0], flag_list); + if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) { +diff --git a/tensorflow_serving/model_servers/server.cc b/tensorflow_serving/model_servers/server.cc +index 9808f9a..b5df129 100644 +--- a/tensorflow_serving/model_servers/server.cc ++++ b/tensorflow_serving/model_servers/server.cc +@@ -330,6 +330,11 @@ Status Server::BuildAndStart(const Options& server_options) { + BuildServerCredentialsFromSSLConfigFile( + server_options.ssl_config_file)); + } ++ if (server_options.set_SyncServerOption_flag) { ++ builder.SetSyncServerOption(::grpc::ServerBuilder::SyncServerOption.NUM_CQS, server_options.NUM_CQS); ++ builder.SetSyncServerOption(::grpc::ServerBuilder::SyncServerOption.MIN_POLLERS, server_options.MIN_POLLERS); ++ builder.SetSyncServerOption(::grpc::ServerBuilder::SyncServerOption.MAX_POLLERS, server_options.MAX_POLLERS); ++ } + builder.RegisterService(model_service_.get()); + builder.RegisterService(prediction_service_.get()); + builder.SetMaxMessageSize(tensorflow::kint32max); +diff --git a/tensorflow_serving/model_servers/server.h b/tensorflow_serving/model_servers/server.h +index 7738f29..90a0994 100644 +--- a/tensorflow_serving/model_servers/server.h ++++ b/tensorflow_serving/model_servers/server.h +@@ -83,7 +83,11 @@ class Server { + bool enforce_session_run_timeout = true; + bool remove_unused_fields_from_bundle_metagraph = true; + bool use_tflite_model = false; +- ++ // SyncServerOption config ++ bool set_SyncServerOption_flag = false; ++ tensorflow::int32 NUM_CQS = 3; ++ tensorflow::int32 MIN_POLLERS = 6; ++ tensorflow::int32 MAX_POLLERS = 12; + Options(); + }; + +-- diff --git a/examples/rec_infer/optimize/README.md b/examples/rec_infer/optimize/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a6d7cd35b71dc3da5d0f883c09c0d228c0e43284 --- /dev/null +++ b/examples/rec_infer/optimize/README.md @@ -0,0 +1,51 @@ +# 链接ARM的optimized-routines库 +在memcpy等接口占比较大的模型中,有性能收益,源码路径为(https://github.com/ARM-software/optimized-routines/tree/v23.01) +```shell +unzip optimized-routines-23.01.zip +cd optimized-routines-23.01 +``` + +在源码基础上,修改代码,修改脚本如下: +```shell +for m in memcmp memcpy memset memmove memrchr strcpy strchrnul strchr strcmp stpcpy strncmp strnlen strrchr; do + for f in $(grep __${m}_aarch64 * -r |awk -F ':' '{print $1}'); do + sed_str1="__${m}_aarch64" + sed_str2="${m}" + sed -i 's!'${sed_str1}'!'${sed_str2}'!g' $f + done +done +``` + +编译: +```shell +make ARCH=aarch64 -j 8 +cp build/lib/libstringlib.so /usr/local/lib/ +``` + +在编译tensorflow serving时链接libstringlib.so,相关修改代码参考0001-Performance-optimization-referrence +运行server时,需要配置环境变量: +```shell +export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH +``` + +# 链接jemalloc库 +源码下载链接: https://github.com/jemalloc/jemalloc/archive/refs/tags/5.3.0.tar.gz +编译安装命令如下: +```shell +tar -xzvf jemalloc-5.3.0.tar.gz +cd jemalloc-5.3.0 +./autogen.sh +make -j 8 +make install +``` + 
+
+# gRPC configuration tuning
+Adding the three options NUM_CQS, MIN_POLLERS, and MAX_POLLERS can improve performance when inference requests come from multiple threads.
+See the gRPC documentation for the options: https://grpc.github.io/grpc/cpp/classgrpc_1_1_server_builder.html
+See 0001-Performance-optimization-referrence for the concrete changes; the optimal values may vary across models and machines.
diff --git a/examples/rec_infer/server.sh b/examples/rec_infer/server.sh
new file mode 100644
index 0000000000000000000000000000000000000000..67166f618bff973e6d32e376a480b08b832bceaf
--- /dev/null
+++ b/examples/rec_infer/server.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
+# Description: startup server
+
+taskset -c 0-32 /home/lmp/serving-1.15.0/bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server \
+    --model_name=saved_model \
+    --model_base_path=$(pwd)/inference_model/saved_model/ \
+    --port=9999 \
+    --rest_api_port=9991 \
+    --platform_config_file=test.cfg
\ No newline at end of file
diff --git a/examples/xDeepFM/IO/base_cache.py b/examples/xDeepFM/IO/base_cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0d8b5fc1ed9f20e347fcb74abc9d9e978eb11a3
--- /dev/null
+++ b/examples/xDeepFM/IO/base_cache.py
@@ -0,0 +1,15 @@
+"""define abstract base class"""
+from npu_bridge.npu_init import *
+import abc
+
+__all__ = ["BaseCache"]
+
+
+class BaseCache(object):
+    """abstract base class"""
+
+    @abc.abstractmethod
+    def write_tfrecord(self, infile, outfile, hparams):
+        """Subclass must implement this."""
+        pass
+
diff --git a/examples/xDeepFM/IO/ffm_cache.py b/examples/xDeepFM/IO/ffm_cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a694d8f4a1a7799823597549ae911940873feae
--- /dev/null
+++ b/examples/xDeepFM/IO/ffm_cache.py
@@ -0,0 +1,164 @@
+"""define FfmCache class for caching the formatted dataset"""
+from npu_bridge.npu_init import *
+from IO.base_cache import BaseCache
+import tensorflow as tf
+import numpy as np
+from collections import defaultdict
+import utils.util as util
+
+__all__ = ["FfmCache"]
+
+
+class FfmCache(BaseCache):
+    # field index starts at 1, feat index starts at 1
+    def _load_batch_data_from_file(self, file, hparams):
+        batch_size = hparams.batch_size
+        labels = []
+        features = []
+        impression_id = []
+        cnt = 0
+        with open(file, 'r') as rd:
+            while True:
+                line = rd.readline().strip(' ')
+                if not line:
+                    break
+                tmp = line.strip().split(util.USER_ID_SPLIT)
+                if len(tmp) == 2:
+                    impression_id.append(tmp[1].strip())
+                    line = tmp[0]
+                cols = line.strip().split(' ')
+                label = float(cols[0].strip())
+                if label > 0:
+                    label = 1
+                else:
+                    label = 0
+                cur_feature_list = []
+                for word in cols[1:]:
+                    if not word.strip():
+                        continue
+                    tokens = word.strip().split(':')
+                    cur_feature_list.append( \
+                        [int(tokens[0]) - 1, \
+                         int(tokens[1]) - 1, \
+                         float(tokens[2])])
+                features.append(cur_feature_list)
+                labels.append(label)
+                cnt += 1
+                if cnt == batch_size:
+                    yield labels, features, impression_id
+                    labels = []
+                    features = []
+                    impression_id = []
+                    cnt = 0
+            if cnt > 0:
+                yield labels, features, impression_id
+
+    def _convert_data(self, labels, features, hparams):
+        dim = hparams.FEATURE_COUNT
+        FIELD_COUNT = hparams.FIELD_COUNT
+        instance_cnt = len(labels)
+
+        fm_feat_indices = []
+        fm_feat_values = []
+        fm_feat_shape = [instance_cnt, dim]
+
+        dnn_feat_indices = []
+        dnn_feat_values = []
+        dnn_feat_weights = []
+        dnn_feat_shape = 
[instance_cnt * FIELD_COUNT, -1] + + for i in range(instance_cnt): + m = len(features[i]) + dnn_feat_dic = {} + for j in range(m): + fm_feat_indices.append([i, features[i][j][1]]) + fm_feat_values.append(features[i][j][2]) + if features[i][j][0] not in dnn_feat_dic: + dnn_feat_dic[features[i][j][0]] = 0 + else: + dnn_feat_dic[features[i][j][0]] += 1 + dnn_feat_indices.append([i * FIELD_COUNT + features[i][j][0], \ + dnn_feat_dic[features[i][j][0]]]) + dnn_feat_values.append(features[i][j][1]) + dnn_feat_weights.append(features[i][j][2]) + if dnn_feat_shape[1] < dnn_feat_dic[features[i][j][0]]: + dnn_feat_shape[1] = dnn_feat_dic[features[i][j][0]] + dnn_feat_shape[1] += 1 + + sorted_index = sorted(range(len(dnn_feat_indices)), + key=lambda k: (dnn_feat_indices[k][0], \ + dnn_feat_indices[k][1])) + + res = {} + res['fm_feat_indices'] = np.asarray(fm_feat_indices, dtype=np.int64) + res['fm_feat_values'] = np.asarray(fm_feat_values, dtype=np.float32) + res['fm_feat_shape'] = np.asarray(fm_feat_shape, dtype=np.int64) + res['labels'] = np.asarray([[label] for label in labels], dtype=np.float32) + + res['dnn_feat_indices'] = np.asarray(dnn_feat_indices, dtype=np.int64)[sorted_index] + res['dnn_feat_values'] = np.asarray(dnn_feat_values, dtype=np.int64)[sorted_index] + res['dnn_feat_weights'] = np.asarray(dnn_feat_weights, dtype=np.float32)[sorted_index] + res['dnn_feat_shape'] = np.asarray(dnn_feat_shape, dtype=np.int64) + return res + + def write_tfrecord(self, infile, outfile, hparams): + sample_num = 0 + FEATURE_COUNT = hparams.FEATURE_COUNT + writer = tf.python_io.TFRecordWriter(outfile) + feature_cnt = defaultdict(lambda: 0) + impression_id_list = [] + try: + for labels, features, impression_id in self._load_batch_data_from_file(infile, hparams): + impression_id_list.extend(impression_id) + sample_num += len(labels) + input_in_sp = self._convert_data(labels, features, hparams) + fm_feat_indices = input_in_sp['fm_feat_indices'] + + for feat in fm_feat_indices: + feature_cnt[feat[1]] += 1 + + fm_feat_values = input_in_sp['fm_feat_values'] + fm_feat_shape = input_in_sp['fm_feat_shape'] + labels = input_in_sp['labels'] + dnn_feat_indices = input_in_sp['dnn_feat_indices'] + dnn_feat_values = input_in_sp['dnn_feat_values'] + dnn_feat_weights = input_in_sp['dnn_feat_weights'] + dnn_feat_shape = input_in_sp['dnn_feat_shape'] + + fm_feat_indices_str = fm_feat_indices.tostring() + labels_str = labels.tostring() + dnn_feat_indices_str = dnn_feat_indices.tostring() + + example = tf.train.Example( + features=tf.train.Features( + feature={ + 'fm_feat_indices': tf.train.Feature( + bytes_list=tf.train.BytesList(value=[fm_feat_indices_str])), + 'fm_feat_values': tf.train.Feature( + float_list=tf.train.FloatList(value=fm_feat_values)), + 'fm_feat_shape': tf.train.Feature( + int64_list=tf.train.Int64List(value=fm_feat_shape)), + 'labels': tf.train.Feature( + bytes_list=tf.train.BytesList(value=[labels_str])), + 'dnn_feat_indices': tf.train.Feature( + bytes_list=tf.train.BytesList(value=[dnn_feat_indices_str])), + 'dnn_feat_values': tf.train.Feature( + int64_list=tf.train.Int64List(value=dnn_feat_values)), + 'dnn_feat_weights': tf.train.Feature( + float_list=tf.train.FloatList(value=dnn_feat_weights)), + 'dnn_feat_shape': tf.train.Feature( + int64_list=tf.train.Int64List(value=dnn_feat_shape)) + } + ) + ) + serialized = example.SerializeToString() + writer.write(serialized) + except: + raise ValueError('train data format must be libffm, for example 1 2:1:0.1 2:3:0.2 3:4:0.4') + writer.close() + sort_feature_cnt 
= sorted(feature_cnt.items(), key=lambda x: x[0]) + with open(util.FEAT_COUNT_FILE, 'w') as f: + for item in sort_feature_cnt: + f.write(str(item[0]) + ',' + str(item[1]) + '\n') + return sample_num, impression_id_list + diff --git a/examples/xDeepFM/IO/iterator.py b/examples/xDeepFM/IO/iterator.py new file mode 100644 index 0000000000000000000000000000000000000000..ad6a9145e364b213480a430d76134094170105d4 --- /dev/null +++ b/examples/xDeepFM/IO/iterator.py @@ -0,0 +1,208 @@ +"""define iterator""" +from npu_bridge.npu_init import * +import collections +import tensorflow as tf +import abc + +BUFFER_SIZE = 256 +__all__ = ["BaseIterator", "FfmIterator", "DinIterator", "CCCFNetIterator"] + + +class BaseIterator(object): + @abc.abstractmethod + def get_iterator(self, src_dataset): + """Subclass must implement this.""" + pass + + @abc.abstractmethod + def parser(self, record): + pass + + +class FfmIterator(BaseIterator): + def __init__(self, src_dataset): + self.get_iterator(src_dataset) + + def get_iterator(self, src_dataset): + src_dataset = src_dataset.map(self.parser) + # src_dataset = src_dataset.shuffle(buffer_size=BUFFER_SIZE) + iterator = src_dataset.make_initializable_iterator() + batch = iterator.get_next() + self.initializer = iterator.initializer + self.fm_feat_indices = batch.get('fm_feat_indices') + self.fm_feat_values = batch.get('fm_feat_values') + self.fm_feat_shape = batch.get('fm_feat_shape') + self.labels = batch.get('labels') + self.dnn_feat_indices = batch.get('dnn_feat_indices') + self.dnn_feat_values = batch.get('dnn_feat_values') + self.dnn_feat_weights = batch.get('dnn_feat_weights') + self.dnn_feat_shape = batch.get('dnn_feat_shape') + + def parser(self, record): + keys_to_features = { + 'fm_feat_indices': tf.FixedLenFeature([], tf.string), + 'fm_feat_values': tf.VarLenFeature(tf.float32), + 'fm_feat_shape': tf.FixedLenFeature([2], tf.int64), + 'labels': tf.FixedLenFeature([], tf.string), + 'dnn_feat_indices': tf.FixedLenFeature([], tf.string), + 'dnn_feat_values': tf.VarLenFeature(tf.int64), + 'dnn_feat_weights': tf.VarLenFeature(tf.float32), + 'dnn_feat_shape': tf.FixedLenFeature([2], tf.int64), + } + parsed = tf.parse_single_example(record, keys_to_features) + fm_feat_indices = tf.reshape(tf.decode_raw(parsed['fm_feat_indices'], tf.int64), [-1, 2]) + fm_feat_values = tf.sparse_tensor_to_dense(parsed['fm_feat_values']) + fm_feat_shape = parsed['fm_feat_shape'] + labels = tf.reshape(tf.decode_raw(parsed['labels'], tf.float32), [-1, 1]) + dnn_feat_indices = tf.reshape(tf.decode_raw(parsed['dnn_feat_indices'], tf.int64), [-1, 2]) + dnn_feat_values = tf.sparse_tensor_to_dense(parsed['dnn_feat_values']) + dnn_feat_weights = tf.sparse_tensor_to_dense(parsed['dnn_feat_weights']) + dnn_feat_shape = parsed['dnn_feat_shape'] + return { + 'fm_feat_indices': fm_feat_indices, 'fm_feat_values': fm_feat_values, 'fm_feat_shape': fm_feat_shape, + 'labels': labels, 'dnn_feat_indices': dnn_feat_indices, 'dnn_feat_values': dnn_feat_values, + 'dnn_feat_weights': dnn_feat_weights, 'dnn_feat_shape': dnn_feat_shape + } + + +class DinIterator(BaseIterator): + def __init__(self, src_dataset): + self.get_iterator(src_dataset) + + def get_iterator(self, src_dataset): + src_dataset = src_dataset.map(self.parser) + # src_dataset = src_dataset.shuffle(buffer_size=BUFFER_SIZE) + iterator = src_dataset.make_initializable_iterator() + output = iterator.get_next() + (_attention_news_indices, _attention_news_values, _attention_news_shape, \ + _attention_user_indices, _attention_user_values, 
_attention_user_weights, \ + _attention_user_shape, _fm_feat_indices, _fm_feat_val, \ + _fm_feat_shape, _labels, _dnn_feat_indices, _dnn_feat_values, \ + _dnn_feat_weight, _dnn_feat_shape) = output + self.initializer = iterator.initializer + self.attention_news_indices = _attention_news_indices + self.attention_news_values = _attention_news_values + self.attention_news_shape = _attention_news_shape + self.attention_user_indices = _attention_user_indices + self.attention_user_values = _attention_user_values + self.attention_user_weights = _attention_user_weights + self.attention_user_shape = _attention_user_shape + self.fm_feat_indices = _fm_feat_indices + self.fm_feat_val = _fm_feat_val + self.fm_feat_shape = _fm_feat_shape + self.labels = _labels + self.dnn_feat_indices = _dnn_feat_indices + self.dnn_feat_values = _dnn_feat_values + self.dnn_feat_weight = _dnn_feat_weight + self.dnn_feat_shape = _dnn_feat_shape + + def parser(self, record): + keys_to_features = { + 'attention_news_indices': tf.FixedLenFeature([], tf.string), + 'attention_news_values': tf.VarLenFeature(tf.float32), + 'attention_news_shape': tf.FixedLenFeature([2], tf.int64), + + 'attention_user_indices': tf.FixedLenFeature([], tf.string), + 'attention_user_values': tf.VarLenFeature(tf.int64), + 'attention_user_weights': tf.VarLenFeature(tf.float32), + 'attention_user_shape': tf.FixedLenFeature([2], tf.int64), + + 'fm_feat_indices': tf.FixedLenFeature([], tf.string), + 'fm_feat_val': tf.VarLenFeature(tf.float32), + 'fm_feat_shape': tf.FixedLenFeature([2], tf.int64), + + 'labels': tf.FixedLenFeature([], tf.string), + + 'dnn_feat_indices': tf.FixedLenFeature([], tf.string), + 'dnn_feat_values': tf.VarLenFeature(tf.int64), + 'dnn_feat_weight': tf.VarLenFeature(tf.float32), + 'dnn_feat_shape': tf.FixedLenFeature([2], tf.int64), + } + parsed = tf.parse_single_example(record, keys_to_features) + + attention_news_indices = tf.reshape(tf.decode_raw(parsed['attention_news_indices'], \ + tf.int64), [-1, 2]) + attention_news_values = tf.sparse_tensor_to_dense(parsed['attention_news_values']) + attention_news_shape = parsed['attention_news_shape'] + + attention_user_indices = tf.reshape(tf.decode_raw(parsed['attention_user_indices'], \ + tf.int64), [-1, 2]) + attention_user_values = tf.sparse_tensor_to_dense(parsed['attention_user_values']) + attention_user_weights = tf.sparse_tensor_to_dense(parsed['attention_user_weights']) + attention_user_shape = parsed['attention_user_shape'] + + fm_feat_indices = tf.reshape(tf.decode_raw(parsed['fm_feat_indices'], \ + tf.int64), [-1, 2]) + fm_feat_val = tf.sparse_tensor_to_dense(parsed['fm_feat_val']) + fm_feat_shape = parsed['fm_feat_shape'] + + labels = tf.reshape(tf.decode_raw(parsed['labels'], tf.float32), [-1, 1]) + + dnn_feat_indices = tf.reshape(tf.decode_raw(parsed['dnn_feat_indices'], \ + tf.int64), [-1, 2]) + dnn_feat_values = tf.sparse_tensor_to_dense(parsed['dnn_feat_values']) + dnn_feat_weight = tf.sparse_tensor_to_dense(parsed['dnn_feat_weight']) + dnn_feat_shape = parsed['dnn_feat_shape'] + return (attention_news_indices, attention_news_values, attention_news_shape, \ + attention_user_indices, attention_user_values, attention_user_weights, \ + attention_user_shape, fm_feat_indices, fm_feat_val, \ + fm_feat_shape, labels, dnn_feat_indices, dnn_feat_values, \ + dnn_feat_weight, dnn_feat_shape) + + +class CCCFNetIterator(BaseIterator): + def __init__(self, src_dataset): + self.get_iterator(src_dataset) + + def get_iterator(self, src_dataset): + src_dataset = 
src_dataset.map(self.parser)
+        # src_dataset = src_dataset.shuffle(buffer_size=BUFFER_SIZE)
+        iterator = src_dataset.make_initializable_iterator()
+        _labels, _userIds, _itemIds, \
+        _user_profiles_indices, _user_profiles_values, _user_profiles_weights, _user_profiles_shape, \
+        _item_profiles_indices, _item_profiles_values, _item_profiles_weights, _item_profiles_shape = iterator.get_next()
+        self.initializer = iterator.initializer
+        self.labels = _labels
+        self.userIds = _userIds
+        self.itemIds = _itemIds
+        self.user_profiles_indices = _user_profiles_indices
+        self.user_profiles_values = _user_profiles_values
+        self.user_profiles_weights = _user_profiles_weights
+        self.user_profiles_shape = _user_profiles_shape
+        self.item_profiles_indices = _item_profiles_indices
+        self.item_profiles_values = _item_profiles_values
+        self.item_profiles_weights = _item_profiles_weights
+        self.item_profiles_shape = _item_profiles_shape
+
+    def parser(self, record):
+        keys_to_features = {
+            'labels': tf.FixedLenFeature([], tf.string),
+            'userIds': tf.VarLenFeature(tf.int64),
+            'itemIds': tf.VarLenFeature(tf.int64),
+            'user_profiles_indices': tf.FixedLenFeature([], tf.string),
+            'user_profiles_values': tf.VarLenFeature(tf.int64),
+            'user_profiles_weights': tf.VarLenFeature(tf.float32),
+            'user_profiles_shape': tf.FixedLenFeature([2], tf.int64),
+            'item_profiles_indices': tf.FixedLenFeature([], tf.string),
+            'item_profiles_values': tf.VarLenFeature(tf.int64),
+            'item_profiles_weights': tf.VarLenFeature(tf.float32),
+            'item_profiles_shape': tf.FixedLenFeature([2], tf.int64)
+        }
+        parsed = tf.parse_single_example(record, keys_to_features)
+        labels = tf.reshape(tf.decode_raw(parsed['labels'], tf.float32), [-1, 1])
+        userIds = tf.sparse_tensor_to_dense(parsed['userIds'])
+        itemIds = tf.sparse_tensor_to_dense(parsed['itemIds'])
+
+        user_profiles_indices = tf.reshape(tf.decode_raw(parsed['user_profiles_indices'], tf.int64), [-1, 2])
+        user_profiles_values = tf.sparse_tensor_to_dense(parsed['user_profiles_values'])
+        user_profiles_weights = tf.sparse_tensor_to_dense(parsed['user_profiles_weights'])
+        user_profiles_shape = parsed['user_profiles_shape']
+
+        item_profiles_indices = tf.reshape(tf.decode_raw(parsed['item_profiles_indices'], tf.int64), [-1, 2])
+        item_profiles_values = tf.sparse_tensor_to_dense(parsed['item_profiles_values'])
+        item_profiles_weights = tf.sparse_tensor_to_dense(parsed['item_profiles_weights'])
+        item_profiles_shape = parsed['item_profiles_shape']
+
+        return labels, userIds, itemIds, \
+               user_profiles_indices, user_profiles_values, user_profiles_weights, user_profiles_shape, \
+               item_profiles_indices, item_profiles_values, item_profiles_weights, item_profiles_shape
\ No newline at end of file
diff --git a/examples/xDeepFM/README.md b/examples/xDeepFM/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d9a937449e2f529aa8dae2b91be880ed9c61cdaf
--- /dev/null
+++ b/examples/xDeepFM/README.md
@@ -0,0 +1,296 @@
+# xDeepFM migration example
+
+## Upstream open-source reference
+
+1. https://github.com/Leavingseason/xDeepFM
+
+2. Commits on Oct 15, 2018, commit SHA-1 hash (commit ID): 114c4c45b1cb6144b2540f92a2b357c3f445e98e
+
+3. Only the code and files needed for execution are kept; everything else has been removed.
+4. 
The config/network.yaml configuration file and data files such as data/dnn/infer.userid.txt and res/infer.userid.txt are downloaded by the user from the open-source link above.
+
+## Migrating to NPU
+
+First use the migration tool for automatic NPU migration, as described in the Ascend community CANN commercial edition documentation: https://www.hiascend.com/document/detail/zh/canncommercial/700/modeldev/tfmigr1/tfmigr1_000009.html
+
+
+## Migrating to mxRec
+
+1. Modify IO/iterator.py. Change lines 30-41
+
+
+```python
+        _fm_feat_indices, _fm_feat_values, +        _fm_feat_shape, _labels, _dnn_feat_indices, +        _dnn_feat_values, _dnn_feat_weights, _dnn_feat_shape = iterator.get_next()
+        self.initializer = iterator.initializer
+        self.fm_feat_indices = _fm_feat_indices
+        self.fm_feat_values = _fm_feat_values
+        self.fm_feat_shape = _fm_feat_shape
+        self.labels = _labels
+        self.dnn_feat_indices = _dnn_feat_indices
+        self.dnn_feat_values = _dnn_feat_values
+        self.dnn_feat_weights = _dnn_feat_weights
+        self.dnn_feat_shape = _dnn_feat_shape
+```
+to:
+```python
+        batch = iterator.get_next()
+        self.initializer = iterator.initializer
+        self.fm_feat_indices = batch.get('fm_feat_indices')
+        self.fm_feat_values = batch.get('fm_feat_values')
+        self.fm_feat_shape = batch.get('fm_feat_shape')
+        self.labels = batch.get('labels')
+        self.dnn_feat_indices = batch.get('dnn_feat_indices')
+        self.dnn_feat_values = batch.get('dnn_feat_values')
+        self.dnn_feat_weights = batch.get('dnn_feat_weights')
+        self.dnn_feat_shape = batch.get('dnn_feat_shape')
+```
+
+Change lines 63-65
+```python
+        return fm_feat_indices, fm_feat_values, +               fm_feat_shape, labels, dnn_feat_indices, +               dnn_feat_values, dnn_feat_weights, dnn_feat_shape
+```
+to:
+```python
+        return {
+            'fm_feat_indices': fm_feat_indices, 'fm_feat_values': fm_feat_values, 'fm_feat_shape': fm_feat_shape,
+            'labels': labels, 'dnn_feat_indices': dnn_feat_indices, 'dnn_feat_values': dnn_feat_values,
+            'dnn_feat_weights': dnn_feat_weights, 'dnn_feat_shape': dnn_feat_shape
+        }
+```
+
+2. Modify src/base_model.py. Set the embedding initializer to tf.zeros_initializer(). Change line 84
+```python
+    return tf.truncated_normal_initializer(stddev=hparams.init_value)
+```
+to (to allow a comparison against CPU, change this in the original xDeepFM source as well):
+```python
+    return tf.zeros_initializer()
+```
+
+Update the label record fetched from the batch of the new dataset generated in automatic graph-modification mode. Change lines 188-189
+```python
+    def eval(self, sess):
+        return sess.run([self.loss, self.data_loss, self.pred, self.iterator.labels], \
+```
+to:
+```python
+    def eval(self, sess, eval_label):
+        return sess.run([self.loss, self.data_loss, self.pred, eval_label], \
+```
+
+3. Modify src/exDeepFM.py. At line 6 add
+```python
+from mx_rec.core.embedding import create_table
+from mx_rec.core.embedding import sparse_lookup
+```
+Change lines 40-43
+```python
+        w_fm_nn_input_orgin = tf.nn.embedding_lookup_sparse(self.embedding,
+                                                            fm_sparse_index,
+                                                            fm_sparse_weight,
+                                                            combiner="sum")
+```
+to:
+```python
+        dense_indices = tf.sparse.to_dense(fm_sparse_index, default_value=0)
+        dense_weights = tf.sparse.to_dense(fm_sparse_weight, default_value=0)
+
+        sparse_hashtable = create_table(key_dtype=tf.int32,
+                                        dim=tf.TensorShape([hparams.dim]),
+                                        name='sparse_embeddings_table',
+                                        emb_initializer=tf.zeros_initializer(),
+                                        device_vocabulary_size=hparams.FEATURE_COUNT,
+                                        host_vocabulary_size=0
+                                        )
+        embedded_values = sparse_lookup(sparse_hashtable,
+                                        dense_indices,
+                                        is_train=True,
+                                        name="sparse_embeddings",
+                                        modify_graph=True)
+        w_fm_nn_input_orgin = tf.reduce_sum(embedded_values * tf.expand_dims(dense_weights, axis=-1), axis=1)
+```
+
+4. Modify main.py. At line 176 add
+```python
+    # init
+    from mx_rec.util.initialize import init
+    init(use_dynamic=True,
+         use_dynamic_expansion=False)
+```
+
+5. Modify train.py. Replace lines 35~57
+```python
+    graph = tf.Graph()
+    with graph.as_default():
+        # feed train file name, valid file name, or test file name
+        filenames = tf.placeholder(tf.string, shape=[None])
+        #src_dataset = tf.contrib.data.TFRecordDataset(filenames)
+        src_dataset = tf.data.TFRecordDataset(filenames)
+
+        if hparams.data_format == 'ffm':
+            batch_input = FfmIterator(src_dataset)
+        elif hparams.data_format == 'din':
+            batch_input = DinIterator(src_dataset)
+        elif hparams.data_format == 'cccfnet':
+            batch_input = CCCFNetIterator(src_dataset)
+        else:
+            raise ValueError("not support {0} format data".format(hparams.data_format))
+        # build model
+        model = model_creator(
+            hparams,
+            iterator=batch_input,
+            scope=scope)
+
+        return TrainModel(
+            graph=graph,
+```
+with:
+```python
+    # feed train file name, valid file name, or test file name
+    filenames = tf.placeholder(tf.string, shape=[None])
+    # src_dataset = tf.contrib.data.TFRecordDataset(filenames)
+    src_dataset = tf.data.TFRecordDataset(filenames)
+
+    if hparams.data_format == 'ffm':
+        batch_input = FfmIterator(src_dataset)
+    elif hparams.data_format == 'din':
+        batch_input = DinIterator(src_dataset)
+    elif hparams.data_format == 'cccfnet':
+        batch_input = CCCFNetIterator(src_dataset)
+    else:
+        raise ValueError("not support {0} format data".format(hparams.data_format))
+    # build model
+    model = model_creator(
+        hparams,
+        iterator=batch_input,
+        scope=scope)
+
+    return TrainModel(
+        graph=tf.get_default_graph(),
+```
+Replace lines 68~73
+```python
+    load_sess.run(load_model.iterator.initializer, feed_dict={load_model.filenames: [filename]})
+    preds = []
+    labels = []
+    while True:
+        try:
+            _, _, step_pred, step_labels = load_model.model.eval(load_sess)
+```
+with:
+```python
+    from mx_rec.util.initialize import ConfigInitializer
+    eval_label = ConfigInitializer.get_instance().train_params_config.get_target_batch(True).get("labels")
+    initializer = ConfigInitializer.get_instance().train_params_config.get_initializer(True)
+    load_sess.run(initializer, feed_dict={load_model.filenames: [filename]})
+    preds = []
+    labels = []
+    while True:
+        try:
+            _, _, step_pred, step_labels = load_model.model.eval(load_sess, eval_label)
+```
+
+At line 223 add
+```python
+    from mx_rec.graph.modifier import modify_graph_and_start_emb_cache
+    modify_graph_and_start_emb_cache(dump_graph=True)
+```
+and replace line 239
+```python
+    train_sess.run(train_model.iterator.initializer, feed_dict={train_model.filenames: [hparams.train_file_cache]})
+```
+with:
+```python
+    from mx_rec.util.initialize import ConfigInitializer
+    initializer = ConfigInitializer.get_instance().train_params_config.get_initializer(True)
+    train_sess.run(initializer, feed_dict={train_model.filenames: [hparams.train_file_cache]})
+```
+6. A run.sh launcher script has been added to fit the mxRec runtime environment.
+
+## Adapting other code
+
+1. Modify utils/util.py. Replace line 63
+
+
+```python
+    config = yaml.load(f)
+```
+with (this change is also needed for the original xDeepFM source to run on CPU):
+```python
+    config = yaml.safe_load(f)
+```
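+
+`yaml.load` without an explicit `Loader` has been deprecated since PyYAML 5.1 and, with an unsafe loader, can instantiate arbitrary Python objects from `!!python/...` tags; `safe_load` restricts parsing to plain YAML types, which is all a config file needs. A small self-contained illustration (the YAML text here is made up for the example):
+```python
+import yaml
+
+text = "model:\n  model_type: exDeepFM\n  dim: 10\n"
+config = yaml.safe_load(text)         # parses only plain maps, lists and scalars
+print(config["model"]["model_type"])  # -> exDeepFM
+
+# yaml.load(text) without a Loader warns on PyYAML 5.x and raises a
+# TypeError on PyYAML 6.x, where the Loader argument became mandatory.
+```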
+
+2. Since the unused src/CIN.py has been dropped, adapt main.py accordingly. Replace lines 156~158
+
+```python
+                                             'opnn', 'fm', 'lr', 'din', 'cccfnet', 'deepcross', 'exDeepFM', "cross", "CIN"]:
+        raise ValueError(
+            "model type must be cccfnet, deepFM, deepWide, dnn, ipnn, opnn, fm, lr, din, deepcross, exDeepFM, cross, CIN but you set is {0}".format(
+```
+with:
+```python
+                                             'opnn', 'fm', 'lr', 'din', 'cccfnet', 'deepcross', 'exDeepFM', "cross"]:
+        raise ValueError(
+            "model type must be cccfnet, deepFM, deepWide, dnn, ipnn, opnn, fm, lr, din, deepcross, exDeepFM, "
+            "cross, but you set is {0}".format(config['model']['model_type']))
+```
+
+Adapt train.py as well: delete line 21
+```python
+from src.CIN import CINModel
+```
+
+and delete lines 210~212
+```python
+    elif hparams.model_type == 'CIN':
+        print("run extreme cin model!")
+        model_creator = CINModel
+```
+
+## Run command
+```shell
+bash run.sh main.py 10.10.10.10
+```
+Here 10.10.10.10 is a placeholder server IP; replace it with the IP of the target server.
+
+## Verifying the results
+1. CPU:
+```log
+step 1 , total_loss: 0.6931, data_loss: 0.6931
+step 2 , total_loss: 0.6905, data_loss: 0.6905
+finish one epoch!
+at epoch 0 train info: loss:0.6918214857578278 eval info: auc:0.4867, logloss:0.6865 test info: auc:0.4867, logloss:0.6865
+at epoch 0 , train time: 0.6 eval time: 0.3
+step 1 , total_loss: 0.6845, data_loss: 0.6845
+step 2 , total_loss: 0.6818, data_loss: 0.6818
+finish one epoch!
+at epoch 1 train info: loss:0.6831814646720886 eval info: auc:0.485, logloss:0.6801 test info: auc:0.485, logloss:0.6801
+at epoch 1 , train time: 0.2 eval time: 0.1
+step 1 , total_loss: 0.6766, data_loss: 0.6766
+step 2 , total_loss: 0.6732, data_loss: 0.6732
+finish one epoch!
+at epoch 2 train info: loss:0.6748818755149841 eval info: auc:0.4832, logloss:0.6738 test info: auc:0.4832, logloss:0.6738
+at epoch 2 , train time: 0.1 eval time: 0.1
+```
+2. mxRec:
+```log
+[1,0]:step 1 , total_loss: 0.6931, data_loss: 0.6931
+[1,0]:step 2 , total_loss: 0.6905, data_loss: 0.6905
+[1,0]:finish one epoch!
+[1,0]:at epoch 0 train info: loss:0.6918215453624725 eval info: auc:0.4867, logloss:0.6865 test info: auc:0.4867, logloss:0.6865
+[1,0]:at epoch 0 , train time: 15.9 eval time: 3.1
+[1,0]:step 1 , total_loss: 0.6845, data_loss: 0.6845
+[1,0]:step 2 , total_loss: 0.6818, data_loss: 0.6818
+[1,0]:finish one epoch!
+[1,0]:at epoch 1 train info: loss:0.6831814646720886 eval info: auc:0.485, logloss:0.6801 test info: auc:0.485, logloss:0.6801
+[1,0]:at epoch 1 , train time: 7.8 eval time: 0.7
+[1,0]:step 1 , total_loss: 0.6766, data_loss: 0.6766
+[1,0]:step 2 , total_loss: 0.6732, data_loss: 0.6732
+[1,0]:finish one epoch!
+[1,0]:at epoch 2 train info: loss:0.6748818457126617 eval info: auc:0.4832, logloss:0.6738 test info: auc:0.4832, logloss:0.6738 +[1,0]:at epoch 2 , train time: 0.5 eval time: 0.7 +``` diff --git a/examples/xDeepFM/main.py b/examples/xDeepFM/main.py new file mode 100644 index 0000000000000000000000000000000000000000..0752d18bcd32b16841fc7cec4d0162c19ec9fca5 --- /dev/null +++ b/examples/xDeepFM/main.py @@ -0,0 +1,195 @@ +"""This script parse and run train function""" +from npu_bridge.npu_init import * +import train +import utils.util as util +import tensorflow as tf +import sys +from utils.log import Log + +#yaml = sys.argv[1] + + + +def flat_config(config): + """flat config to a dict""" + f_config = {} + category = ['data', 'model', 'train', 'info'] + for cate in category: + for key, val in config[cate].items(): + f_config[key] = val + return f_config + + +def create_hparams(FLAGS): + """Create hparams.""" + FLAGS = flat_config(FLAGS) + return tf.contrib.training.HParams( + # data + train_file=FLAGS['train_file'] if 'train_file' in FLAGS else None, + eval_file=FLAGS['eval_file'] if 'eval_file' in FLAGS else None, + test_file=FLAGS['test_file'] if 'test_file' in FLAGS else None, + infer_file=FLAGS['infer_file'] if 'infer_file' in FLAGS else None, + FEATURE_COUNT=FLAGS['FEATURE_COUNT'] if 'FEATURE_COUNT' in FLAGS else None, + FIELD_COUNT=FLAGS['FIELD_COUNT'] if 'FIELD_COUNT' in FLAGS else None, + data_format=FLAGS['data_format'] if 'data_format' in FLAGS else None, + PAIR_NUM=FLAGS['PAIR_NUM'] if 'PAIR_NUM' in FLAGS else None, + DNN_FIELD_NUM=FLAGS['DNN_FIELD_NUM'] if 'DNN_FIELD_NUM' in FLAGS else None, + n_user=FLAGS['n_user'] if 'n_user' in FLAGS else None, + n_item=FLAGS['n_item'] if 'n_item' in FLAGS else None, + n_user_attr=FLAGS['n_user_attr'] if 'n_user_attr' in FLAGS else None, + n_item_attr=FLAGS['n_item_attr'] if 'n_item_attr' in FLAGS else None, + # model + dim=FLAGS['dim'] if 'dim' in FLAGS else None, + layer_sizes=FLAGS['layer_sizes'] if 'layer_sizes' in FLAGS else None, + cross_layer_sizes=FLAGS['cross_layer_sizes'] if 'cross_layer_sizes' in FLAGS else None, + cross_layers = FLAGS['cross_layers'] if 'cross_layers' in FLAGS else None, + activation=FLAGS['activation'] if 'activation' in FLAGS else None, + cross_activation=FLAGS['cross_activation'] if 'cross_activation' in FLAGS else "identity", + dropout=FLAGS['dropout'] if 'dropout' in FLAGS else None, + attention_layer_sizes=FLAGS['attention_layer_sizes'] if 'attention_layer_sizes' in FLAGS else None, + attention_activation=FLAGS['attention_activation'] if 'attention_activation' in FLAGS else None, + model_type=FLAGS['model_type'] if 'model_type' in FLAGS else None, + method=FLAGS['method'] if 'method' in FLAGS else None, + load_model_name=FLAGS['load_model_name'] if 'load_model_name' in FLAGS else None, + mu=FLAGS['mu'] if 'mu' in FLAGS else None, + # train + init_method=FLAGS['init_method'] if 'init_method' in FLAGS else 'tnormal', + init_value=FLAGS['init_value'] if 'init_value' in FLAGS else 0.01, + embed_l2=FLAGS['embed_l2'] if 'embed_l2' in FLAGS else 0.0000, + embed_l1=FLAGS['embed_l1'] if 'embed_l1' in FLAGS else 0.0000, + layer_l2=FLAGS['layer_l2'] if 'layer_l2' in FLAGS else 0.0000, + layer_l1=FLAGS['layer_l1'] if 'layer_l1' in FLAGS else 0.0000, + cross_l2=FLAGS['cross_l2'] if 'cross_l2' in FLAGS else 0.0000, + cross_l1=FLAGS['cross_l1'] if 'cross_l1' in FLAGS else 0.0000, + learning_rate=FLAGS['learning_rate'] if 'learning_rate' in FLAGS else 0.001, + loss=FLAGS['loss'] if 'loss' in FLAGS else None, + 
optimizer=FLAGS['optimizer'] if 'optimizer' in FLAGS else 'adam', + epochs=FLAGS['epochs'] if 'epochs' in FLAGS else 10, + batch_size=FLAGS['batch_size'] if 'batch_size' in FLAGS else 1, + # show info + log=FLAGS['log'] if 'log' in FLAGS else "log", + logger=None, + show_step=FLAGS['show_step'] if 'show_step' in FLAGS else 1, + save_epoch=FLAGS['save_epoch'] if 'save_epoch' in FLAGS else 5, + metrics=FLAGS['metrics'] if 'metrics' in FLAGS else None + ) + + +def check_type(config): + """check config type""" + # check parameter type + int_parameters = ['FEATURE_COUNT', 'FIELD_COUNT', 'dim', 'epochs', 'batch_size', 'show_step', \ + 'save_epoch', 'PAIR_NUM', 'DNN_FIELD_NUM', 'attention_layer_sizes', \ + 'n_user', 'n_item', 'n_user_attr', 'n_item_attr'] + for param in int_parameters: + if param in config and not isinstance(config[param], int): + raise TypeError("parameters {0} must be int".format(param)) + + float_parameters = ['init_value', 'learning_rate', 'embed_l2', \ + 'embed_l1', 'layer_l2', 'layer_l1', 'mu'] + for param in float_parameters: + if param in config and not isinstance(config[param], float): + raise TypeError("parameters {0} must be float".format(param)) + + str_parameters = ['train_file', 'eval_file', 'test_file', 'infer_file', 'method', \ + 'load_model_name', 'loss', 'optimizer', 'init_method', 'attention_activation'] + for param in str_parameters: + if param in config and not isinstance(config[param], str): + raise TypeError("parameters {0} must be str".format(param)) + + list_parameters = ['layer_sizes', 'activation', 'dropout'] + for param in list_parameters: + if param in config and not isinstance(config[param], list): + raise TypeError("parameters {0} must be list".format(param)) + + if ('data_format' in config) and (not config['data_format'] in ['ffm', 'din', 'cccfnet']): + raise TypeError("parameters data_format must be din" \ + ",ffm, cccfnet but is {0}".format(config['data_format'])) + + +def check_nn_config(config): + """check neural networks config""" + if config['model']['model_type'] in ['fm']: + required_parameters = ['train_file', 'eval_file', 'FEATURE_COUNT', 'dim', 'loss', 'data_format', 'method'] + elif config['model']['model_type'] in ['lr']: + required_parameters = ['train_file', 'eval_file', 'FEATURE_COUNT', 'loss', 'data_format', 'method'] + elif config['model']['model_type'] in ['din']: + required_parameters = ['train_file', 'eval_file', 'PAIR_NUM', 'DNN_FIELD_NUM', 'FEATURE_COUNT', 'dim', \ + 'layer_sizes', 'activation', 'attention_layer_sizes', 'attention_activation', 'loss', \ + 'data_format', 'dropout', 'method'] + elif config['model']['model_type'] in ['cccfnet']: + required_parameters = ['train_file', 'eval_file', 'dim', 'layer_sizes', 'n_user', 'n_item', 'n_user_attr', + 'n_item_attr', + 'activation', 'loss', 'data_format', 'dropout', 'mu', 'method'] + elif config['model']['model_type'] in ['exDeepFM']: + required_parameters = ['train_file', 'eval_file', 'FIELD_COUNT', 'FEATURE_COUNT', 'method', + 'dim', 'layer_sizes', 'cross_layer_sizes', 'activation', 'loss', 'data_format', 'dropout'] + elif config['model']['model_type'] in ['deepcross']: + required_parameters = ['train_file', 'eval_file', 'FIELD_COUNT', 'FEATURE_COUNT', 'method', + 'dim', 'layer_sizes', 'cross_layers', 'activation', 'loss', 'data_format', + 'dropout'] + else: + required_parameters = ['train_file', 'eval_file', 'FIELD_COUNT', 'FEATURE_COUNT', 'method', + 'dim', 'layer_sizes', 'activation', 'loss', 'data_format', 'dropout'] + f_config = flat_config(config) + # check required 
parameters
+    for param in required_parameters:
+        if param not in f_config:
+            raise ValueError("parameters {0} must be set".format(param))
+    if f_config['model_type'] == 'din':
+        if f_config['data_format'] != 'din':
+            raise ValueError(
+                "for din model, data format must be din, but your set is {0}".format(f_config['data_format']))
+    elif f_config['model_type'] == 'cccfnet':
+        if f_config['data_format'] != 'cccfnet':
+            raise ValueError(
+                "for cccfnet model, data format must be cccfnet, but your set is {0}".format(f_config['data_format']))
+    else:
+        if f_config['data_format'] != 'ffm':
+            raise ValueError("data format must be ffm, but your set is {0}".format(f_config['data_format']))
+    check_type(f_config)
+
+
+def check_config(config):
+    """check networks config"""
+    if config['model']['model_type'] not in ['deepFM', 'deepWide', 'dnn', 'ipnn', \
+                                             'opnn', 'fm', 'lr', 'din', 'cccfnet', 'deepcross', 'exDeepFM', "cross"]:
+        raise ValueError(
+            "model type must be cccfnet, deepFM, deepWide, dnn, ipnn, opnn, fm, lr, din, deepcross, exDeepFM, "
+            "cross, but you set is {0}".format(config['model']['model_type']))
+    check_nn_config(config)
+
+
+# train process load yaml
+def load_yaml():
+    """load config from yaml"""
+    yaml_name = util.CONFIG_DIR + util.TRAIN_YAML
+    print('training network configuration file is {0}'.format(yaml_name))
+    util.check_file_exist(yaml_name)
+    config = util.load_yaml_file(yaml_name)
+    return config
+
+
+def main():
+    """main function"""
+
+    # init
+    from mx_rec.util.initialize import init
+    init(use_dynamic=True,
+         use_dynamic_expansion=False)
+
+    # flag = True
+    util.check_tensorflow_version()
+    util.check_and_mkdir()
+    #util.TRAIN_YAML = yaml
+    config = load_yaml()
+    check_config(config)
+    hparams = create_hparams(config)
+    print(hparams.values())
+    log = Log(hparams)
+    hparams.logger = log.logger
+    train.train(hparams)
+
+
+main()
+
diff --git a/examples/xDeepFM/run.sh b/examples/xDeepFM/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..613c440e705b242b86d6a2c557be041b366050a3
--- /dev/null
+++ b/examples/xDeepFM/run.sh
@@ -0,0 +1,130 @@
+kill -9 `ps -ef | grep python | grep -v grep | awk '{print $2}'` > /dev/null 2>&1
+
+# read the input arguments: py (entry script), ip (server address)
+if [ $# -ge 1 ]; then
+    py=$1
+    ip=$2
+else
+    echo "for example: bash run.sh main.py 10.10.10.10 or bash run.sh main.py"
+    exit 1
+fi
+
+# check that the given Python file name is well-formed
+if [[ $py =~ ^[a-z0-9_]+\.py$ ]]; then
+    echo "File $py is a valid Python file"
+else
+    echo "File $py is not a Python file"
+    exit 1
+fi
+
+# check whether the IP address is valid
+if [ -n "$ip" ]; then
+    if [[ $ip =~ ^([0-9]{1,3}\.){3}[0-9]{1,3}$ ]]; then
+        # split the IP address into its four octets
+        ip_array=(${ip//./ })
+        # check that every octet is within 0-255
+        valid=true
+        for i in "${ip_array[@]}"; do
+            if ((i < 0 || i > 255)); then
+                valid=false
+                break
+            fi
+        done
+        if $valid; then
+            echo "ip: $ip is valid"
+        else
+            echo "ip: $ip is not valid"
+            exit 1
+        fi
+    else
+        echo "ip: $ip is not valid."
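+        # the argument did not match the dotted-quad IPv4 pattern, abort early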
+ exit 1 + fi +fi + +cur_path=`pwd` +mx_rec_package_path="/usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec" # please config +so_path=${mx_rec_package_path}/libasc +# GLOG_stderrthreshold -2:TRACE -1:DEBUG 0:INFO 1:WARN 2.ERROR, 默认为INFO +mpi_args='-x BIND_INFO="0:12 12:48 60:48" -x GLOG_stderrthreshold=2 -x GLOG_logtostderr=true -bind-to none -x NCCL_SOCKET_IFNAME=docker0 -mca btl_tcp_if_exclude docker0' +interface="lo" +local_rank_size=1 # 每个节点使用的NPU卡数 +num_server=1 # 训练节点数 +num_process=$((${num_server} * ${local_rank_size})) # 训练总的进程数,等于使用的NPU卡的总数 + +export HCCL_CONNECT_TIMEOUT=1200 # HCCL集合通信 建链超时时间,取值范围[120,7200] +export PYTHONPATH=${so_path}:$PYTHONPATH # 环境python安装路径 +export LD_PRELOAD=/usr/lib64/libgomp.so.1:/usr/local/python3.7.5/lib/python3.7/site-packages/scikit_learn.libs/libgomp-d22c30c5.so.1.0.0 +export LD_LIBRARY_PATH=${so_path}:/usr/local/lib:$LD_LIBRARY_PATH +# 集合通信文件,格式请参考昇腾官网CANN文档,“准备资源配置文件”章节。 +export JOB_ID=10086 +# 训练任务使用的NPU卡数总数 +export MXREC_LOG_LEVEL="ERROR" # 框架日志等级 +export TF_CPP_MIN_LOG_LEVEL=3 # tensorflow日志级别,3对应FATAL +# 设置应用类日志的全局日志级别及各模块日志级别,具体请参考昇腾官网CANN文档 +export ASCEND_GLOBAL_LOG_LEVEL=3 # “设置日志级别”章节0:debug, 1:info, 2:warning, 3:error, 4:NULL +export MXREC_MODE="ASC" +export USE_MPI=1 + +# 帮助信息,不需要修改 +if [[ $1 == --help || $1 == -h ]];then + echo "Usage: ./run.sh [OPTION]... [IP]..." + echo " " + echo "parameter explain: + [OPTION] main.py + [IP] IP address of the host + -h/--help show help message + " + exit 1 +fi + +# 使用ranktable方案 +function rankTableSolution() { + echo "The ranktable solution" + export RANK_TABLE_FILE="${cur_path}/hccl_json_${local_rank_size}p.json" + export RANK_SIZE=$num_process + export ASCEND_VISIBLE_DEVICES="0" + export RANK_ID=0 + export ASCEND_DEVICE_ID=$RANK_ID + echo "RANK_TABLE_FILE=$RANK_TABLE_FILE" + if [ ! -f "$RANK_TABLE_FILE" ];then + echo "the rank table file does not exit. Please reference {hccl_json_${local_rank_size}p.json} to correctly config rank table file" + exit 1 + fi +} + +if [ ! -n "$ip" ]; then + rankTableSolution +else + VALID_CHECK=$(echo $ip|awk -F. '$1<=255&&$2<=255&&$3<=255&&$4<=255{print "yes"}') + if echo $ip|grep -E "^[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}$">/dev/null; then + if [ "$VALID_CHECK" == "yes" ]; then + #################使用去除ranktable方案时开启###################### + echo "ip: $ip available." + echo "The ranktable solution is removed." + export CM_CHIEF_IP=$ip # 主节点ip + export CM_CHIEF_PORT=6000 # 主节点监听端口 + export CM_CHIEF_DEVICE=0 # 主节点device id + export CM_WORKER_IP=$ip # 当前节点ip + export CM_WORKER_SIZE=$num_process # 参与集群训练的device数量 + echo "CM_CHIEF_IP=$CM_CHIEF_IP" + echo "CM_CHIEF_PORT=$CM_CHIEF_PORT" + echo "CM_CHIEF_DEVICE=$CM_CHIEF_DEVICE" + echo "CM_WORKER_IP=$CM_WORKER_IP" + echo "CM_WORKER_SIZE=$CM_WORKER_SIZE" + echo "ASCEND_VISIBLE_DEVICES=$ASCEND_VISIBLE_DEVICES" + ######################################################### + else + echo "ip: $ip not available!" # 使用ranktable方案 + rankTableSolution + fi + else + echo "ip: $ip not available!" 
# 使用ranktable方案 + rankTableSolution + fi +fi + +echo "use horovod to start tasks" +DATE=$(date +%Y-%m-%d-%H-%M-%S) +horovodrun --network-interface ${interface} -np ${num_process} --mpi-args "${mpi_args}" --mpi -H localhost:${local_rank_size} \ +python3.7 ${py} 2>&1 | tee "temp_${local_rank_size}p_t_${DATE}.log" diff --git a/examples/xDeepFM/src/base_model.py b/examples/xDeepFM/src/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..96682ae6c40c49017c59d57b6b9843cd0a24299f --- /dev/null +++ b/examples/xDeepFM/src/base_model.py @@ -0,0 +1,195 @@ +"""define base class model""" +from npu_bridge.npu_init import * +import abc +import math +import tensorflow as tf +import utils.util as util +from IO.iterator import BaseIterator + +__all__ = ["BaseModel"] + + +class BaseModel(object): + def __init__(self, hparams, iterator, scope=None): + assert isinstance(iterator, BaseIterator) + tf.set_random_seed(1234) + self.iterator = iterator + self.layer_params = [] + self.embed_params = [] + self.cross_params = [] + self.layer_keeps = None + self.keep_prob_train = None + self.keep_prob_test = None + self.initializer = self._get_initializer(hparams) + self.logit = self._build_graph(hparams) + self.pred = self._get_pred(self.logit, hparams) + self.data_loss = self._compute_data_loss(hparams) + self.regular_loss = self._compute_regular_loss(hparams) + self.loss = tf.add(self.data_loss, self.regular_loss) + self.saver = tf.train.Saver(max_to_keep=hparams.epochs) + self.update = self._build_train_opt(hparams) + self.init_op = tf.global_variables_initializer() + self.merged = self._add_summaries() + + def _get_pred(self, logit, hparams): + if hparams.method == 'regression': + pred = tf.identity(logit) + elif hparams.method == 'classification': + pred = tf.sigmoid(logit) + else: + raise ValueError("method must be regression or classification, but now is {0}".format(hparams.method)) + return pred + + def _add_summaries(self): + tf.summary.scalar("data_loss", self.data_loss) + tf.summary.scalar("regular_loss", self.regular_loss) + tf.summary.scalar("loss", self.loss) + merged = tf.summary.merge_all() + return merged + + @abc.abstractmethod + def _build_graph(self, hparams): + """Subclass must implement this.""" + pass + + def _l2_loss(self, hparams): + l2_loss = tf.zeros([1], dtype=tf.float32) + # embedding_layer l2 loss + for param in self.embed_params: + l2_loss = tf.add(l2_loss, tf.multiply(hparams.embed_l2, tf.nn.l2_loss(param))) + params = self.layer_params + for param in params: + l2_loss = tf.add(l2_loss, tf.multiply(hparams.layer_l2, tf.nn.l2_loss(param))) + return l2_loss + + def _l1_loss(self, hparams): + l1_loss = tf.zeros([1], dtype=tf.float32) + # embedding_layer l2 loss + for param in self.embed_params: + l1_loss = tf.add(l1_loss, tf.multiply(hparams.embed_l1, tf.norm(param, ord=1))) + params = self.layer_params + for param in params: + l1_loss = tf.add(l1_loss, tf.multiply(hparams.layer_l1, tf.norm(param, ord=1))) + return l1_loss + + def _cross_l_loss(self, hparams): + cross_l_loss = tf.zeros([1], dtype=tf.float32) + for param in self.cross_params: + cross_l_loss = tf.add(cross_l_loss, tf.multiply(hparams.cross_l1, tf.norm(param, ord=1))) + cross_l_loss = tf.add(cross_l_loss, tf.multiply(hparams.cross_l2, tf.norm(param, ord=1))) + return cross_l_loss + + def _get_initializer(self, hparams): + if hparams.init_method == 'tnormal': + return tf.zeros_initializer() + elif hparams.init_method == 'uniform': + return tf.random_uniform_initializer(-hparams.init_value, 
hparams.init_value) + elif hparams.init_method == 'normal': + return tf.random_normal_initializer(stddev=hparams.init_value) + elif hparams.init_method == 'xavier_normal': + return tf.contrib.layers.xavier_initializer(uniform=False) + elif hparams.init_method == 'xavier_uniform': + return tf.contrib.layers.xavier_initializer(uniform=True) + elif hparams.init_method == 'he_normal': + return tf.contrib.layers.variance_scaling_initializer( \ + factor=2.0, mode='FAN_IN', uniform=False) + elif hparams.init_method == 'he_uniform': + return tf.contrib.layers.variance_scaling_initializer( \ + factor=2.0, mode='FAN_IN', uniform=True) + else: + return tf.truncated_normal_initializer(stddev=hparams.init_value) + + def _compute_data_loss(self, hparams): + if hparams.loss == 'cross_entropy_loss': + data_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( \ + logits=tf.reshape(self.logit, [-1]), \ + labels=tf.reshape(self.iterator.labels, [-1]))) + elif hparams.loss == 'square_loss': + data_loss = tf.sqrt(tf.reduce_mean( + tf.squared_difference(tf.reshape(self.pred, [-1]), tf.reshape(self.iterator.labels, [-1])))) + elif hparams.loss == 'log_loss': + data_loss = tf.reduce_mean(tf.losses.log_loss(predictions=tf.reshape(self.pred, [-1]), + labels=tf.reshape(self.iterator.labels, [-1]))) + else: + raise ValueError("this loss not defined {0}".format(hparams.loss)) + return data_loss + + def _compute_regular_loss(self, hparams): + regular_loss = self._l2_loss(hparams) + self._l1_loss(hparams) + self._cross_l_loss(hparams) + regular_loss = tf.reduce_sum(regular_loss) + return regular_loss + + def _build_train_opt(self, hparams): + def train_opt(hparams): + if hparams.optimizer == 'adadelta': + train_step = tf.train.AdadeltaOptimizer( \ + hparams.learning_rate).minimize(self.loss) + elif hparams.optimizer == 'adagrad': + train_step = tf.train.AdagradOptimizer( \ + hparams.learning_rate).minimize(self.loss) + elif hparams.optimizer == 'sgd': + train_step = tf.train.GradientDescentOptimizer( \ + hparams.learning_rate).minimize(self.loss) + elif hparams.optimizer == 'adam': + train_step = tf.train.AdamOptimizer( \ + hparams.learning_rate).minimize(self.loss) + elif hparams.optimizer == 'ftrl': + train_step = tf.train.FtrlOptimizer( \ + hparams.learning_rate).minimize(self.loss) + elif hparams.optimizer == 'gd': + train_step = tf.train.GradientDescentOptimizer( \ + hparams.learning_rate).minimize(self.loss) + elif hparams.optimizer == 'padagrad': + train_step = tf.train.ProximalAdagradOptimizer( \ + hparams.learning_rate).minimize(self.loss) + elif hparams.optimizer == 'pgd': + train_step = tf.train.ProximalGradientDescentOptimizer( \ + hparams.learning_rate).minimize(self.loss) + elif hparams.optimizer == 'rmsprop': + train_step = tf.train.RMSPropOptimizer( \ + hparams.learning_rate).minimize(self.loss) + else: + train_step = tf.train.GradientDescentOptimizer( \ + hparams.learning_rate).minimize(self.loss) + return train_step + + train_step = train_opt(hparams) + return train_step + + def _active_layer(self, logit, scope, activation, layer_idx): + logit = self._dropout(logit, layer_idx) + logit = self._activate(logit, activation) + return logit + + def _activate(self, logit, activation): + if activation == 'sigmoid': + return tf.nn.sigmoid(logit) + elif activation == 'softmax': + return tf.nn.softmax(logit) + elif activation == 'relu': + return tf.nn.relu(logit) + elif activation == 'tanh': + return tf.nn.tanh(logit) + elif activation == 'elu': + return tf.nn.elu(logit) + elif activation == 'identity': 
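+            # identity: pass the logits through unchanged (the default for cross_activation in this example)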
+ return tf.identity(logit) + else: + raise ValueError("this activations not defined {0}".format(activation)) + + def _dropout(self, logit, layer_idx): + logit = npu_ops.dropout(x=logit, keep_prob=self.layer_keeps[layer_idx]) + return logit + + def train(self, sess): + return sess.run([self.update, self.loss, self.data_loss, self.merged], \ + feed_dict={self.layer_keeps: self.keep_prob_train}) + + def eval(self, sess, eval_label): + return sess.run([self.loss, self.data_loss, self.pred, eval_label], \ + feed_dict={self.layer_keeps: self.keep_prob_test}) + + def infer(self, sess): + return sess.run([self.pred], \ + feed_dict={self.layer_keeps: self.keep_prob_test}) + diff --git a/examples/xDeepFM/src/exDeepFM.py b/examples/xDeepFM/src/exDeepFM.py new file mode 100644 index 0000000000000000000000000000000000000000..9d5b529993404c647cff86d118f384f90e5c967a --- /dev/null +++ b/examples/xDeepFM/src/exDeepFM.py @@ -0,0 +1,425 @@ +"""define Factorization-Machine based Neural Network Model""" +from npu_bridge.npu_init import * +import math +import numpy as np +import tensorflow as tf +from mx_rec.core.embedding import create_table +from mx_rec.core.embedding import sparse_lookup +from src.base_model import BaseModel + +__all__ = ["ExtremeDeepFMModel"] + + +class ExtremeDeepFMModel(BaseModel): + """define Factorization-Machine based Neural Network Model""" + + def _build_graph(self, hparams): + self.keep_prob_train = 1 - np.array(hparams.dropout) + self.keep_prob_test = np.ones_like(hparams.dropout) + self.layer_keeps = tf.placeholder(tf.float32) + with tf.variable_scope("exDeepFm") as scope: + with tf.variable_scope("embedding", initializer=self.initializer) as escope: + self.embedding = tf.get_variable(name='embedding_layer', + shape=[hparams.FEATURE_COUNT, hparams.dim], + dtype=tf.float32) + self.embed_params.append(self.embedding) + embed_out, embed_layer_size = self._build_embedding(hparams) + logit = self._build_linear(hparams) + # logit = tf.add(logit, self._build_fm(hparams)) + # res: use resnet? direct: without split? reduce_D: Dimension reduction? 
f_dim: dimension of reduce_D + logit = tf.add(logit, self._build_extreme_FM(hparams, embed_out, res=False, direct=False, bias=False, reduce_D=False, f_dim=2)) + # logit = tf.add(logit, self._build_extreme_FM_quick(hparams, embed_out)) + logit = tf.add(logit, self._build_dnn(hparams, embed_out, embed_layer_size)) + return logit + + def _build_embedding(self, hparams): + fm_sparse_index = tf.SparseTensor(self.iterator.dnn_feat_indices, + self.iterator.dnn_feat_values, + self.iterator.dnn_feat_shape) + fm_sparse_weight = tf.SparseTensor(self.iterator.dnn_feat_indices, + self.iterator.dnn_feat_weights, + self.iterator.dnn_feat_shape) + dense_indices = tf.sparse.to_dense(fm_sparse_index, default_value=0) + dense_weights = tf.sparse.to_dense(fm_sparse_weight, default_value=0) + + sparse_hashtable = create_table(key_dtype=tf.int32, + dim=tf.TensorShape([hparams.dim]), + name='sparse_embeddings_table', + emb_initializer=tf.zeros_initializer(), + device_vocabulary_size=hparams.FEATURE_COUNT, + host_vocabulary_size=0 + ) + embedded_values = sparse_lookup(sparse_hashtable, + dense_indices, + is_train=True, + name="sparse_embeddings", + modify_graph=True) + w_fm_nn_input_orgin = tf.reduce_sum(embedded_values * tf.expand_dims(dense_weights, axis=-1), axis=1) + embedding = tf.reshape(w_fm_nn_input_orgin, [-1, hparams.dim * hparams.FIELD_COUNT]) + embedding_size = hparams.FIELD_COUNT * hparams.dim + return embedding, embedding_size + + def _build_linear(self, hparams): + with tf.variable_scope("linear_part", initializer=self.initializer) as scope: + w_linear = tf.get_variable(name='w', + shape=[hparams.FEATURE_COUNT, 1], + dtype=tf.float32) + b_linear = tf.get_variable(name='b', + shape=[1], + dtype=tf.float32, + initializer=tf.zeros_initializer()) + x = tf.SparseTensor(self.iterator.fm_feat_indices, + self.iterator.fm_feat_values, + self.iterator.fm_feat_shape) + linear_output = tf.add(tf.sparse_tensor_dense_matmul(x, w_linear), b_linear) + self.layer_params.append(w_linear) + self.layer_params.append(b_linear) + tf.summary.histogram("linear_part/w", w_linear) + tf.summary.histogram("linear_part/b", b_linear) + return linear_output + + def _build_fm(self, hparams): + with tf.variable_scope("fm_part") as scope: + x = tf.SparseTensor(self.iterator.fm_feat_indices, + self.iterator.fm_feat_values, + self.iterator.fm_feat_shape) + xx = tf.SparseTensor(self.iterator.fm_feat_indices, + tf.pow(self.iterator.fm_feat_values, 2), + self.iterator.fm_feat_shape) + fm_output = 0.5 * tf.reduce_sum( + tf.pow(tf.sparse_tensor_dense_matmul(x, self.embedding), 2) - \ + tf.sparse_tensor_dense_matmul(xx, + tf.pow(self.embedding, 2)), 1, + keep_dims=True) + return fm_output + """ + def _build_extreme_FM_slow_bad(self, hparams, nn_input): + hidden_nn_layers = [] + field_nums = [] + final_len = 0 + field_num = hparams.FIELD_COUNT + nn_input = tf.reshape(nn_input, shape=[-1, int(field_num), hparams.dim]) + field_nums.append(int(field_num)) + hidden_nn_layers.append(nn_input) + final_result = [] + with tf.variable_scope("exfm_part", initializer=self.initializer) as scope: + for idx, layer_size in enumerate(hparams.cross_layer_sizes): + dot_results = [] + split_tensor = tf.split(hidden_nn_layers[-1], field_nums[-1]*[1], 1) + for s in split_tensor: + s = tf.tile(s, [1, field_nums[0], 1]) + dot_results.append(tf.multiply(s, hidden_nn_layers[0])) + dot_result = tf.concat(dot_results, axis=1) + filters = tf.get_variable(name="f_"+str(idx), + shape=[1, len(dot_results)*field_nums[0], layer_size], + dtype=tf.float32) + dot_result = 
tf.transpose(dot_result, perm=[0, 2, 1]) + curr_out = tf.nn.conv1d(dot_result, filters=filters, stride=1, padding='VALID') + curr_out = tf.transpose(curr_out, perm=[0, 2, 1]) + + if idx != len(hparams.cross_layer_sizes)-1: + next_hidden, direct_connect = tf.split(curr_out, 2*[int(layer_size / 2)], 1) + final_len += int(layer_size / 2) + else: + direct_connect = curr_out + next_hidden=0 + final_len += layer_size + + ### + direct_connect = curr_out + next_hidden = curr_out + final_len += layer_size + ### + + final_result.append(direct_connect) + hidden_nn_layers.append(next_hidden) + field_nums.append(int(layer_size / 2)) + # field_nums.append(int(layer_size)) + self.cross_params.append(filters) + result = tf.concat(final_result, axis=1) + result = tf.reduce_sum(result, -1) + ### + # residual network + w_nn_output1 = tf.get_variable(name='w_nn_output1', + shape=[final_len, 128], + dtype=tf.float32) + b_nn_output1 = tf.get_variable(name='b_nn_output1', + shape=[128], + dtype=tf.float32, + initializer=tf.zeros_initializer()) + self.layer_params.append(w_nn_output1) + self.layer_params.append(b_nn_output1) + exFM_out0 = tf.nn.xw_plus_b(result, w_nn_output1, b_nn_output1) + exFM_out1 = self._active_layer(logit=exFM_out0, + scope=scope, + activation="relu", + layer_idx=0) + w_nn_output2 = tf.get_variable(name='w_nn_output2', + shape=[128 + final_len, 1], + dtype=tf.float32) + b_nn_output2 = tf.get_variable(name='b_nn_output2', + shape=[1], + dtype=tf.float32, + initializer=tf.zeros_initializer()) + self.layer_params.append(w_nn_output2) + self.layer_params.append(b_nn_output2) + exFM_in = tf.concat([exFM_out1, result], axis=1, name="user_emb") + exFM_out = tf.nn.xw_plus_b(exFM_in, w_nn_output2, b_nn_output2) + + ### + w_nn_output = tf.get_variable(name='w_nn_output', + shape=[final_len, 1], + dtype=tf.float32) + b_nn_output = tf.get_variable(name='b_nn_output', + shape=[1], + dtype=tf.float32) + self.layer_params.append(w_nn_output) + self.layer_params.append(b_nn_output) + exFM_out = tf.nn.xw_plus_b(result, w_nn_output, b_nn_output) + + return exFM_out + """ + + def _build_extreme_FM(self, hparams, nn_input, res=False, direct=False, bias=False, reduce_D=False, f_dim=2): + hidden_nn_layers = [] + field_nums = [] + final_len = 0 + field_num = hparams.FIELD_COUNT + nn_input = tf.reshape(nn_input, shape=[-1, int(field_num), hparams.dim]) + field_nums.append(int(field_num)) + hidden_nn_layers.append(nn_input) + final_result = [] + split_tensor0 = tf.split(hidden_nn_layers[0], hparams.dim * [1], 2) + with tf.variable_scope("exfm_part", initializer=self.initializer) as scope: + for idx, layer_size in enumerate(hparams.cross_layer_sizes): + split_tensor = tf.split(hidden_nn_layers[-1], hparams.dim * [1], 2) + dot_result_m = tf.matmul(split_tensor0, split_tensor, transpose_b=True) + dot_result_o = tf.reshape(dot_result_m, shape=[hparams.dim, -1, field_nums[0]*field_nums[-1]]) + dot_result = tf.transpose(dot_result_o, perm=[1, 0, 2]) + + if reduce_D: + hparams.logger.info("reduce_D") + filters0 = tf.get_variable("f0_" + str(idx), + shape=[1, layer_size, field_nums[0], f_dim], + dtype=tf.float32) + filters_ = tf.get_variable("f__" + str(idx), + shape=[1, layer_size, f_dim, field_nums[-1]], + dtype=tf.float32) + filters_m = tf.matmul(filters0, filters_) + filters_o = tf.reshape(filters_m, shape=[1, layer_size, field_nums[0] * field_nums[-1]]) + filters = tf.transpose(filters_o, perm=[0, 2, 1]) + else: + filters = tf.get_variable(name="f_"+str(idx), + shape=[1, field_nums[-1]*field_nums[0], layer_size], + 
dtype=tf.float32) + # dot_result = tf.transpose(dot_result, perm=[0, 2, 1]) + curr_out = tf.nn.conv1d(dot_result, filters=filters, stride=1, padding='VALID') + + # BIAS ADD + if bias: + hparams.logger.info("bias") + b = tf.get_variable(name="f_b" + str(idx), + shape=[layer_size], + dtype=tf.float32, + initializer=tf.zeros_initializer()) + curr_out = tf.nn.bias_add(curr_out, b) + self.cross_params.append(b) + self.layer_params.append(b) + + curr_out = self._activate(curr_out, hparams.cross_activation) + + curr_out = tf.transpose(curr_out, perm=[0, 2, 1]) + + if direct: + hparams.logger.info("all direct connect") + direct_connect = curr_out + next_hidden = curr_out + final_len += layer_size + field_nums.append(int(layer_size)) + + else: + hparams.logger.info("split connect") + if idx != len(hparams.cross_layer_sizes) - 1: + next_hidden, direct_connect = tf.split(curr_out, 2 * [int(layer_size / 2)], 1) + final_len += int(layer_size / 2) + else: + direct_connect = curr_out + next_hidden = 0 + final_len += layer_size + field_nums.append(int(layer_size / 2)) + + final_result.append(direct_connect) + hidden_nn_layers.append(next_hidden) + + self.cross_params.append(filters) + self.layer_params.append(filters) + + result = tf.concat(final_result, axis=1) + result = tf.reduce_sum(result, -1) + if res: + hparams.logger.info("residual network") + w_nn_output1 = tf.get_variable(name='w_nn_output1', + shape=[final_len, 128], + dtype=tf.float32) + b_nn_output1 = tf.get_variable(name='b_nn_output1', + shape=[128], + dtype=tf.float32, + initializer=tf.zeros_initializer()) + self.layer_params.append(w_nn_output1) + self.layer_params.append(b_nn_output1) + exFM_out0 = tf.nn.xw_plus_b(result, w_nn_output1, b_nn_output1) + exFM_out1 = self._active_layer(logit=exFM_out0, + scope=scope, + activation="relu", + layer_idx=0) + w_nn_output2 = tf.get_variable(name='w_nn_output2', + shape=[128 + final_len, 1], + dtype=tf.float32) + b_nn_output2 = tf.get_variable(name='b_nn_output2', + shape=[1], + dtype=tf.float32, + initializer=tf.zeros_initializer()) + self.layer_params.append(w_nn_output2) + self.layer_params.append(b_nn_output2) + exFM_in = tf.concat([exFM_out1, result], axis=1, name="user_emb") + exFM_out = tf.nn.xw_plus_b(exFM_in, w_nn_output2, b_nn_output2) + + else: + hparams.logger.info("no residual network") + w_nn_output = tf.get_variable(name='w_nn_output', + shape=[final_len, 1], + dtype=tf.float32) + b_nn_output = tf.get_variable(name='b_nn_output', + shape=[1], + dtype=tf.float32, + initializer=tf.zeros_initializer()) + self.layer_params.append(w_nn_output) + self.layer_params.append(b_nn_output) + exFM_out = tf.nn.xw_plus_b(result, w_nn_output, b_nn_output) + + return exFM_out + + def _build_extreme_FM_quick(self, hparams, nn_input): + hidden_nn_layers = [] + field_nums = [] + final_len = 0 + field_num = hparams.FIELD_COUNT + nn_input = tf.reshape(nn_input, shape=[-1, int(field_num), hparams.dim]) + field_nums.append(int(field_num)) + hidden_nn_layers.append(nn_input) + final_result = [] + split_tensor0 = tf.split(hidden_nn_layers[0], hparams.dim * [1], 2) + with tf.variable_scope("exfm_part", initializer=self.initializer) as scope: + for idx, layer_size in enumerate(hparams.cross_layer_sizes): + split_tensor = tf.split(hidden_nn_layers[-1], hparams.dim * [1], 2) + dot_result_m = tf.matmul(split_tensor0, split_tensor, transpose_b=True) + dot_result_o = tf.reshape(dot_result_m, shape=[hparams.dim, -1, field_nums[0]*field_nums[-1]]) + dot_result = tf.transpose(dot_result_o, perm=[1, 0, 2]) + + filters = 
tf.get_variable(name="f_"+str(idx), + shape=[1, field_nums[-1]*field_nums[0], layer_size], + dtype=tf.float32) + # dot_result = tf.transpose(dot_result, perm=[0, 2, 1]) + curr_out = tf.nn.conv1d(dot_result, filters=filters, stride=1, padding='VALID') + + + curr_out = tf.transpose(curr_out, perm=[0, 2, 1]) + + + hparams.logger.info("split connect") + if idx != len(hparams.cross_layer_sizes) - 1: + next_hidden, direct_connect = tf.split(curr_out, 2 * [int(layer_size / 2)], 1) + final_len += int(layer_size / 2) + else: + direct_connect = curr_out + next_hidden = 0 + final_len += layer_size + field_nums.append(int(layer_size / 2)) + + final_result.append(direct_connect) + hidden_nn_layers.append(next_hidden) + + self.cross_params.append(filters) + + result = tf.concat(final_result, axis=1) + result = tf.reduce_sum(result, -1) + + hparams.logger.info("no residual network") + w_nn_output = tf.get_variable(name='w_nn_output', + shape=[final_len, 1], + dtype=tf.float32) + b_nn_output = tf.get_variable(name='b_nn_output', + shape=[1], + dtype=tf.float32, + initializer=tf.zeros_initializer()) + self.layer_params.append(w_nn_output) + self.layer_params.append(b_nn_output) + exFM_out = tf.nn.xw_plus_b(result, w_nn_output, b_nn_output) + + return exFM_out + + + def _build_dnn(self, hparams, embed_out, embed_layer_size): + """ + fm_sparse_index = tf.SparseTensor(self.iterator.dnn_feat_indices, + self.iterator.dnn_feat_values, + self.iterator.dnn_feat_shape) + fm_sparse_weight = tf.SparseTensor(self.iterator.dnn_feat_indices, + self.iterator.dnn_feat_weights, + self.iterator.dnn_feat_shape) + w_fm_nn_input_orgin = tf.nn.embedding_lookup_sparse(self.embedding, + fm_sparse_index, + fm_sparse_weight, + combiner="sum") + w_fm_nn_input = tf.reshape(w_fm_nn_input_orgin, [-1, hparams.dim * hparams.FIELD_COUNT]) + last_layer_size = hparams.FIELD_COUNT * hparams.dim + """ + w_fm_nn_input = embed_out + last_layer_size = embed_layer_size + layer_idx = 0 + hidden_nn_layers = [] + hidden_nn_layers.append(w_fm_nn_input) + with tf.variable_scope("nn_part", initializer=self.initializer) as scope: + for idx, layer_size in enumerate(hparams.layer_sizes): + curr_w_nn_layer = tf.get_variable(name='w_nn_layer' + str(layer_idx), + shape=[last_layer_size, layer_size], + dtype=tf.float32) + curr_b_nn_layer = tf.get_variable(name='b_nn_layer' + str(layer_idx), + shape=[layer_size], + dtype=tf.float32, + initializer=tf.zeros_initializer()) + tf.summary.histogram("nn_part/" + 'w_nn_layer' + str(layer_idx), + curr_w_nn_layer) + tf.summary.histogram("nn_part/" + 'b_nn_layer' + str(layer_idx), + curr_b_nn_layer) + curr_hidden_nn_layer = tf.nn.xw_plus_b(hidden_nn_layers[layer_idx], + curr_w_nn_layer, + curr_b_nn_layer) + scope = "nn_part" + str(idx) + activation = hparams.activation[idx] + curr_hidden_nn_layer = self._active_layer(logit=curr_hidden_nn_layer, + scope=scope, + activation=activation, + layer_idx=idx) + hidden_nn_layers.append(curr_hidden_nn_layer) + layer_idx += 1 + last_layer_size = layer_size + self.layer_params.append(curr_w_nn_layer) + self.layer_params.append(curr_b_nn_layer) + + w_nn_output = tf.get_variable(name='w_nn_output', + shape=[last_layer_size, 1], + dtype=tf.float32) + b_nn_output = tf.get_variable(name='b_nn_output', + shape=[1], + dtype=tf.float32, + initializer=tf.zeros_initializer()) + tf.summary.histogram("nn_part/" + 'w_nn_output' + str(layer_idx), + w_nn_output) + tf.summary.histogram("nn_part/" + 'b_nn_output' + str(layer_idx), + b_nn_output) + self.layer_params.append(w_nn_output) + 
self.layer_params.append(b_nn_output) + nn_output = tf.nn.xw_plus_b(hidden_nn_layers[-1], w_nn_output, b_nn_output) + return nn_output + diff --git a/examples/xDeepFM/train.py b/examples/xDeepFM/train.py new file mode 100644 index 0000000000000000000000000000000000000000..39918b34a795fdd62b6b3c96c5993f93e8fe2a10 --- /dev/null +++ b/examples/xDeepFM/train.py @@ -0,0 +1,310 @@ +"""define train, infer, eval, test process""" +from npu_bridge.npu_init import * +import numpy as np +import os, time, collections +import tensorflow as tf +from IO.iterator import FfmIterator #, DinIterator, CCCFNetIterator +#from IO.din_cache import DinCache +from IO.ffm_cache import FfmCache +#from IO.cccfnet_cache import CCCFNetCache +#from src.deep_fm import DeepfmModel +#from src.deep_wide import DeepWideModel +#from src.fm import FmModel +#from src.dnn import DnnModel +#from src.opnn import OpnnModel +#from src.ipnn import IpnnModel +#from src.lr import LrModel +#from src.din import DinModel +#from src.cccfnet import CCCFModel +#from src.deepcross import DeepCrossModel +from src.exDeepFM import ExtremeDeepFMModel +#from src.cross import CrossModel +import utils.util as util +import utils.metric as metric +# from utils.log import Log + +# log = Log(hparams) + +class TrainModel(collections.namedtuple("TrainModel", ("graph", "model", "iterator", "filenames"))): + """define train class, include graph, model, iterator""" + pass + + +def create_train_model(model_creator, hparams, scope=None): + # feed train file name, valid file name, or test file name + filenames = tf.placeholder(tf.string, shape=[None]) + # src_dataset = tf.contrib.data.TFRecordDataset(filenames) + src_dataset = tf.data.TFRecordDataset(filenames) + + if hparams.data_format == 'ffm': + batch_input = FfmIterator(src_dataset) + elif hparams.data_format == 'din': + batch_input = DinIterator(src_dataset) + elif hparams.data_format == 'cccfnet': + batch_input = CCCFNetIterator(src_dataset) + else: + raise ValueError("not support {0} format data".format(hparams.data_format)) + # build model + model = model_creator( + hparams, + iterator=batch_input, + scope=scope) + + return TrainModel( + graph=tf.get_default_graph(), + model=model, + iterator=batch_input, + filenames=filenames) + + +# run evaluation and get evaluted loss +def run_eval(load_model, load_sess, filename, sample_num_file, hparams, flag): + # load sample num + with open(sample_num_file, 'r') as f: + sample_num = int(f.readlines()[0].strip()) + from mx_rec.util.initialize import ConfigInitializer + eval_label = ConfigInitializer.get_instance().train_params_config.get_target_batch(True).get("labels") + initializer = ConfigInitializer.get_instance().train_params_config.get_initializer(True) + load_sess.run(initializer, feed_dict={load_model.filenames: [filename]}) + preds = [] + labels = [] + while True: + try: + _, _, step_pred, step_labels = load_model.model.eval(load_sess, eval_label) + preds.extend(np.reshape(step_pred, -1)) + labels.extend(np.reshape(step_labels, -1)) + except tf.errors.OutOfRangeError: + break + preds = preds[:sample_num] + labels = labels[:sample_num] + hparams.logger.info("data num:{0:d}".format(len(labels))) + res = metric.cal_metric(labels, preds, hparams, flag) + return res + + +# run infer +def run_infer(load_model, load_sess, filename, hparams, sample_num_file): + # load sample num + with open(sample_num_file, 'r') as f: + sample_num = int(f.readlines()[0].strip()) + if not os.path.exists(util.RES_DIR): + os.mkdir(util.RES_DIR) + 
load_sess.run(load_model.iterator.initializer, feed_dict={load_model.filenames: [filename]}) + preds = [] + while True: + try: + step_pred = load_model.model.infer(load_sess) + preds.extend(np.reshape(step_pred, -1)) + except tf.errors.OutOfRangeError: + break + preds = preds[:sample_num] + hparams.res_name = util.convert_res_name(hparams.infer_file) + # print('result name:', hparams.res_name) + with open(hparams.res_name, 'w') as out: + out.write('\n'.join(map(str, preds))) + + +# cache data +def cache_data(hparams, filename, flag): + if hparams.data_format == 'ffm': + cache_obj = FfmCache() + elif hparams.data_format == 'din': + cache_obj = DinCache() + elif hparams.data_format == 'cccfnet': + cache_obj = CCCFNetCache() + else: + raise ValueError( + "data format must be ffm, din, cccfnet, this format not defined {0}".format(hparams.data_format)) + if not os.path.exists(util.CACHE_DIR): + os.mkdir(util.CACHE_DIR) + if flag == 'train': + hparams.train_file_cache = util.convert_cached_name(hparams.train_file, hparams.batch_size) + cached_name = hparams.train_file_cache + sample_num_path = util.TRAIN_NUM + impression_id_path = util.TRAIN_IMPRESSION_ID + elif flag == 'eval': + hparams.eval_file_cache = util.convert_cached_name(hparams.eval_file, hparams.batch_size) + cached_name = hparams.eval_file_cache + sample_num_path = util.EVAL_NUM + impression_id_path = util.EVAL_IMPRESSION_ID + elif flag == 'test': + hparams.test_file_cache = util.convert_cached_name(hparams.test_file, hparams.batch_size) + cached_name = hparams.test_file_cache + sample_num_path = util.TEST_NUM + impression_id_path = util.TEST_IMPRESSION_ID + elif flag == 'infer': + hparams.infer_file_cache = util.convert_cached_name(hparams.infer_file, hparams.batch_size) + cached_name = hparams.infer_file_cache + sample_num_path = util.INFER_NUM + impression_id_path = util.INFER_IMPRESSION_ID + else: + raise ValueError("flag must be train, eval, test, infer") + print('cache filename:', filename) + if not os.path.isfile(cached_name): + print('has not cached file, begin cached...') + start_time = time.time() + sample_num, impression_id_list = cache_obj.write_tfrecord(filename, cached_name, hparams) + util.print_time("caced file used time", start_time) + print("data sample num:{0}".format(sample_num)) + with open(sample_num_path, 'w') as f: + f.write(str(sample_num) + '\n') + with open(impression_id_path, 'w') as f: + for impression_id in impression_id_list: + f.write(str(impression_id) + '\n') + + +def train(hparams, scope=None, target_session=""): + params = hparams.values() + for key, val in params.items(): + hparams.logger.info(str(key) + ':' + str(val)) + + print('load and cache data...') + if hparams.train_file is not None: + cache_data(hparams, hparams.train_file, flag='train') + if hparams.eval_file is not None: + cache_data(hparams, hparams.eval_file, flag='eval') + if hparams.test_file is not None: + cache_data(hparams, hparams.test_file, flag='test') + if hparams.infer_file is not None: + cache_data(hparams, hparams.infer_file, flag='infer') + + if hparams.model_type == 'deepFM': + model_creator = DeepfmModel + print("run deepfm model!") + elif hparams.model_type == 'deepWide': + model_creator = DeepWideModel + print("run deepWide model!") + elif hparams.model_type == 'dnn': + print("run dnn model!") + model_creator = DnnModel + elif hparams.model_type == 'ipnn': + print("run ipnn model!") + model_creator = IpnnModel + elif hparams.model_type == 'opnn': + print("run opnn model!") + model_creator = OpnnModel + elif 
hparams.model_type == 'din': + print("run din model!") + model_creator = DinModel + elif hparams.model_type == 'fm': + print("run fm model!") + model_creator = FmModel + elif hparams.model_type == 'lr': + print("run lr model!") + model_creator = LrModel + elif hparams.model_type == 'din': + print("run din model!") + model_creator = DinModel + elif hparams.model_type == 'cccfnet': + print("run cccfnet model!") + model_creator = CCCFModel + elif hparams.model_type == 'deepcross': + print("run deepcross model!") + model_creator = DeepCrossModel + elif hparams.model_type == 'exDeepFM': + print("run extreme deepFM model!") + model_creator = ExtremeDeepFMModel + elif hparams.model_type == 'cross': + print("run extreme cross model!") + model_creator = CrossModel + + else: + raise ValueError("model type should be cccfnet, deepFM, deepWide, dnn, fm, lr, ipnn, opnn, din") + + # define train,eval,infer graph + # define train session, eval session, infer session + train_model = create_train_model(model_creator, hparams, scope) + gpuconfig = tf.ConfigProto() + gpuconfig.gpu_options.allow_growth = True + tf.set_random_seed(1234) + + from mx_rec.graph.modifier import modify_graph_and_start_emb_cache + modify_graph_and_start_emb_cache(dump_graph=True) + + train_sess = tf.Session(target=target_session, graph=train_model.graph, config=npu_config_proto(config_proto=gpuconfig)) + + train_sess.run(train_model.model.init_op) + # load model from checkpoint + if not hparams.load_model_name is None: + checkpoint_path = hparams.load_model_name + try: + train_model.model.saver.restore(train_sess, checkpoint_path) + print('load model', checkpoint_path) + except: + raise IOError("Failed to find any matching files for {0}".format(checkpoint_path)) + print('total_loss = data_loss+regularization_loss, data_loss = {rmse or logloss ..}') + writer = tf.summary.FileWriter(util.SUMMARIES_DIR, train_sess.graph) + last_eval = 0 + for epoch in range(hparams.epochs): + step = 0 + from mx_rec.util.initialize import ConfigInitializer + initializer = ConfigInitializer.get_instance().train_params_config.get_initializer(True) + train_sess.run(initializer, feed_dict={train_model.filenames: [hparams.train_file_cache]}) + + epoch_loss = 0 + train_start = time.time() + train_load_time = 0 + while True: + try: + t1 = time.time() + step_result = train_model.model.train(train_sess) + t3 = time.time() + train_load_time += t3 - t1 + (_, step_loss, step_data_loss, summary) = step_result + writer.add_summary(summary, step) + epoch_loss += step_loss + step += 1 + if step % hparams.show_step == 0: + print('step {0:d} , total_loss: {1:.4f}, data_loss: {2:.4f}' \ + .format(step, step_loss, step_data_loss)) + except tf.errors.OutOfRangeError: + print('finish one epoch!') + break + train_end = time.time() + train_time = train_end - train_start + if epoch % hparams.save_epoch == 0: + checkpoint_path = train_model.model.saver.save( + sess=train_sess, + save_path=util.MODEL_DIR + 'epoch_' + str(epoch)) + # print(checkpoint_path) + train_res = dict() + train_res["loss"] = epoch_loss / step + eval_start = time.time() + # train_res = run_eval(train_model, train_sess, hparams.train_file_cache, util.TRAIN_NUM, hparams, flag='train') + eval_res = run_eval(train_model, train_sess, hparams.eval_file_cache, util.EVAL_NUM, hparams, flag='eval') + train_info = ', '.join( + [str(item[0]) + ':' + str(item[1]) + for item in sorted(train_res.items(), key=lambda x: x[0])]) + eval_info = ', '.join( + [str(item[0]) + ':' + str(item[1]) + for item in 
sorted(eval_res.items(), key=lambda x: x[0])]) + if hparams.test_file is not None: + test_res = run_eval(train_model, train_sess, hparams.test_file_cache, util.TEST_NUM, hparams, flag='test') + test_info = ', '.join( + [str(item[0]) + ':' + str(item[1]) + for item in sorted(test_res.items(), key=lambda x: x[0])]) + eval_end = time.time() + eval_time = eval_end - eval_start + if hparams.test_file is not None: + print('at epoch {0:d}'.format( + epoch) + ' train info: ' + train_info + ' eval info: ' + eval_info + ' test info: ' + test_info) + hparams.logger.info('at epoch {0:d}'.format( + epoch) + ' train info: ' + train_info + ' eval info: ' + eval_info + ' test info: ' + test_info) + else: + print('at epoch {0:d}'.format(epoch) + ' train info: ' + train_info + ' eval info: ' + eval_info) + hparams.logger.info('at epoch {0:d}'.format(epoch) + ' train info: ' + train_info + ' eval info: ' + eval_info) + print('at epoch {0:d} , train time: {1:.1f} eval time: {2:.1f}'.format(epoch, train_time, eval_time)) + + hparams.logger.info('at epoch {0:d} , train time: {1:.1f} eval time: {2:.1f}' \ + .format(epoch, train_time, eval_time)) + hparams.logger.info('\n') + + if eval_res["auc"] - last_eval < - 0.003: + break + if eval_res["auc"] > last_eval: + last_eval = eval_res["auc"] + writer.close() + # after train,run infer + if hparams.infer_file is not None: + run_infer(train_model, train_sess, hparams.infer_file_cache, hparams, util.INFER_NUM) + diff --git a/examples/xDeepFM/utils/log.py b/examples/xDeepFM/utils/log.py new file mode 100644 index 0000000000000000000000000000000000000000..9b0c2c921e5425a83f5de3e37f2553270d112aec --- /dev/null +++ b/examples/xDeepFM/utils/log.py @@ -0,0 +1,22 @@ +"""define logging configure""" +from npu_bridge.npu_init import * +import logging +from datetime import datetime, timedelta, timezone +import platform + +__all__ = ["Log"] +class Log(object): + def __init__(self, hparams): + # UTC To Beijing Time + utc_dt = datetime.utcnow().replace(tzinfo=timezone.utc) + bj_dt = utc_dt.astimezone(timezone(timedelta(hours=8))) + + logging_filename = "logs/"+hparams.log + '__' + bj_dt.strftime('%Y-%m-%d_%H_%M_%S') + '.log' + self.logger = logging.getLogger(__name__) + self.logger.setLevel(logging.INFO) + handler = logging.FileHandler(logging_filename) + handler.setLevel(logging.INFO) + formatter = logging.Formatter('%(message)s') + handler.setFormatter(formatter) + self.logger.addHandler(handler) + diff --git a/examples/xDeepFM/utils/metric.py b/examples/xDeepFM/utils/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..2a44b47b1076b0e6ff4310cafb05352918057d91 --- /dev/null +++ b/examples/xDeepFM/utils/metric.py @@ -0,0 +1,99 @@ +"""define metrics""" +from npu_bridge.npu_init import * +from collections import defaultdict +from sklearn.metrics import roc_auc_score, log_loss, mean_squared_error +import numpy as np +import utils.util as util + + +def cal_metric(labels, preds, hparams, flag): + """Calculate metrics,such as auc, logloss, group auc""" + res = {} + + def load_impression_id(file_name): + """load impression id, such as user id, news id""" + id_list = [] + with open(file_name, 'r') as f_in: + for line in f_in: + id_list.append(line.strip()) + return id_list + + for metric in hparams.metrics: + if metric == 'auc': + auc = roc_auc_score(np.asarray(labels), np.asarray(preds)) + res['auc'] = round(auc, 4) + elif metric == 'rmse': + rmse = mean_squared_error(np.asarray(labels), np.asarray(preds)) + res['rmse'] = np.sqrt(round(rmse, 4)) + elif metric == 
'logloss': + # avoid logloss nan + preds = [max(min(p, 1. - 10e-12), 10e-12) for p in preds] + logloss = log_loss(np.asarray(labels), np.asarray(preds)) + res['logloss'] = round(logloss, 4) + elif metric == 'group_auc': + if flag == 'train': + impression_id_list = load_impression_id(util.TRAIN_IMPRESSION_ID) + if len(impression_id_list) == 0: + raise ValueError("train data does not has impressionId," \ + "so can not cal the group auc!") + group_auc = cal_group_auc(labels, preds, impression_id_list) + res['group_auc'] = group_auc + elif flag == 'eval': + impression_id_list = load_impression_id(util.EVAL_IMPRESSION_ID) + if len(impression_id_list) == 0: + raise ValueError("eval data does not has impressionId," \ + "so can not cal the group auc!") + group_auc = cal_group_auc(labels, preds, impression_id_list) + res['group_auc'] = group_auc + elif flag == 'test': + impression_id_list = load_impression_id(util.INFER_IMPRESSION_ID) + if len(impression_id_list) == 0: + raise ValueError("infer data does not has impressionId," \ + "so can not cal the group auc!") + group_auc = cal_group_auc(labels, preds, impression_id_list) + res['group_auc'] = group_auc + else: + raise ValueError("cal metric dataSet should be train, eval , test") + + else: + raise ValueError("not define this metric {0}".format(metric)) + return res + + +def cal_group_auc(labels, preds, impression_id_list): + """Calculate group auc""" + if len(impression_id_list) != len(labels): + raise ValueError( + "impression id num should equal to the sample num," \ + "impression id num is {0}".format(len(impression_id_list))) + group_score = defaultdict(lambda: []) + group_truth = defaultdict(lambda: []) + for idx, truth in enumerate(labels): + user_id = impression_id_list[idx] + score = preds[idx] + truth = labels[idx] + group_score[user_id].append(score) + group_truth[user_id].append(truth) + + group_flag = defaultdict(lambda: False) + for user_id in set(impression_id_list): + truths = group_truth[user_id] + flag = False + for i in range(len(truths) - 1): + if truths[i] != truths[i + 1]: + flag = True + break + group_flag[user_id] = flag + + impression_total = 0 + total_auc = 0 + # + for user_id in group_flag: + if group_flag[user_id]: + auc = roc_auc_score(np.asarray(group_truth[user_id]), np.asarray(group_score[user_id])) + total_auc += auc * len(group_truth[user_id]) + impression_total += len(group_truth[user_id]) + group_auc = float(total_auc) / impression_total + group_auc = round(group_auc, 4) + return group_auc + diff --git a/examples/xDeepFM/utils/util.py b/examples/xDeepFM/utils/util.py new file mode 100644 index 0000000000000000000000000000000000000000..e4d886369f125f8e31d5f3709314a580668c5bef --- /dev/null +++ b/examples/xDeepFM/utils/util.py @@ -0,0 +1,85 @@ +"""define util function and global variable""" +from npu_bridge.npu_init import * +import tensorflow as tf +import os, sys +import time, yaml +from packaging import version + +RES_DIR = './res/' +CACHE_DIR = './cache/' +MODEL_DIR = './checkpoint/' +CONFIG_DIR = './config/' +TRAIN_YAML = 'network.yaml' +TRAIN_NUM = './cache/train_num.csv' +EVAL_NUM = './cache/eval_num.csv' +TEST_NUM = './cache/test_num.csv' +INFER_NUM = './cache/infer_num.csv' +LOG_DIR = './logs/' +FEAT_COUNT_FILE = './cache/feat_cnt.csv' +TRAIN_IMPRESSION_ID = './cache/train_impressionId.csv' +EVAL_IMPRESSION_ID = './cache/eval_impressionId.csv' +TEST_IMPRESSION_ID = './cache/test_impressionId.csv' +INFER_IMPRESSION_ID = './cache/infer_impressionId.csv' +SUMMARIES_DIR = './logs/' +# define din format 
diff --git a/examples/xDeepFM/utils/util.py b/examples/xDeepFM/utils/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4d886369f125f8e31d5f3709314a580668c5bef
--- /dev/null
+++ b/examples/xDeepFM/utils/util.py
@@ -0,0 +1,85 @@
+"""define util functions and global variables"""
+from npu_bridge.npu_init import *
+import os
+import sys
+import time
+
+import tensorflow as tf
+import yaml
+from packaging import version
+
+RES_DIR = './res/'
+CACHE_DIR = './cache/'
+MODEL_DIR = './checkpoint/'
+CONFIG_DIR = './config/'
+TRAIN_YAML = 'network.yaml'
+TRAIN_NUM = './cache/train_num.csv'
+EVAL_NUM = './cache/eval_num.csv'
+TEST_NUM = './cache/test_num.csv'
+INFER_NUM = './cache/infer_num.csv'
+LOG_DIR = './logs/'
+FEAT_COUNT_FILE = './cache/feat_cnt.csv'
+TRAIN_IMPRESSION_ID = './cache/train_impressionId.csv'
+EVAL_IMPRESSION_ID = './cache/eval_impressionId.csv'
+TEST_IMPRESSION_ID = './cache/test_impressionId.csv'
+INFER_IMPRESSION_ID = './cache/infer_impressionId.csv'
+SUMMARIES_DIR = './logs/'
+# separator inside a din-format feature
+DIN_FORMAT_SPLIT = '#'
+# separator between feature and userid
+USER_ID_SPLIT = '%'
+
+
+def check_and_mkdir():
+    def make_dir(dir_path):
+        if not os.path.exists(dir_path):
+            os.makedirs(dir_path)
+
+    make_dir(RES_DIR)
+    make_dir(CACHE_DIR)
+    make_dir(MODEL_DIR)
+    make_dir(CONFIG_DIR)
+    make_dir(LOG_DIR)
+
+
+def check_tensorflow_version():
+    if version.parse(tf.__version__) < version.parse("1.2.0"):
+        raise EnvironmentError("TensorFlow version must be >= 1.2.0, "
+                               "but the current version is {0}".format(tf.__version__))
+
+
+def print_time(s, start_time):
+    """Take a start time, print elapsed duration, and return a new time."""
+    print("%s, %ds, %s." % (s, (time.time() - start_time), time.ctime()))
+    sys.stdout.flush()
+    return time.time()
+
+
+def check_file_exist(filename):
+    if not os.path.isfile(filename):
+        raise ValueError("{0} does not exist".format(filename))
+
+
+def load_yaml_file(filename):
+    with open(filename) as f:
+        try:
+            config = yaml.safe_load(f)
+        except yaml.YAMLError as err:
+            raise IOError("load {0} error!".format(filename)) from err
+    return config
+
+
+def convert_cached_name(file_name, batch_size):
+    prefix = CACHE_DIR + 'batch_size_' + str(batch_size) + '_'
+    prefix += (file_name.strip().split('/'))[-1]
+    train_cache_name = prefix.replace(".txt", ".tfrecord"). \
+        replace(".csv", ".tfrecord"). \
+        replace(".libsvm", ".tfrecord")
+    return train_cache_name
+
+
+def convert_res_name(file_name):
+    # strip the known extension once, then append the result suffix;
+    # chained str.replace calls would re-replace the suffix itself
+    inferfile = file_name.split('/')[-1]
+    for ext in (".tfrecord", ".csv", ".libsvm"):
+        if inferfile.endswith(ext):
+            inferfile = inferfile[:-len(ext)]
+            break
+    return RES_DIR + inferfile + ".res.csv"
+
diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py
index bdb851311fab5242fb08a65c118a3342a7a9ef83..618d802edabd28c55323bce685d26e636bb57bfe 100644
--- a/mx_rec/__init__.py
+++ b/mx_rec/__init__.py
@@ -15,17 +15,20 @@
 # limitations under the License.
 # ==============================================================================
 
+__version__ = "5.0.RC2"
 __all__ = ["version", "__version__"]
 
 from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION
 from mx_rec.util.tf_version_adapter import npu_ops, hccl_ops, NPUCheckpointSaverHook
-from mx_rec.saver.patch import patch_for_saver
+from mx_rec.saver.patch import patch_for_saver, patch_for_summary_writer
 from mx_rec.graph.patch import patch_for_dataset, patch_for_chief_session_creator, patch_for_bool_gauge, \
     patch_for_assert_eval_spec, patch_for_scale_loss, patch_for_session
 from mx_rec.data.patch import patch_for_dataset_eos_map
 from mx_rec.optimizers.base import patch_for_optimizer
+from mx_rec.saver.warm_start import patch_for_warm_start
 
 patch_for_saver()
+patch_for_summary_writer()
 patch_for_dataset()
 patch_for_dataset_eos_map()
 patch_for_scale_loss()
@@ -34,7 +37,7 @@ patch_for_assert_eval_spec()
 patch_for_bool_gauge()
 patch_for_optimizer()
 patch_for_session()
-__version__ = "5.0.RC2"
+patch_for_warm_start()
 
 
 def version():
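The `mx_rec/__init__.py` hunk above also illustrates the package's convention: every `patch_for_*` function is applied once, at import time, so that importing `mx_rec` transparently rewires the relevant TensorFlow entry points. For readers unfamiliar with this import-time monkey-patching style, a minimal self-contained sketch of the pattern; `target_module` and its `save` function are stand-ins, not mxRec or TensorFlow APIs:

```python
# Minimal sketch of the import-time patching pattern used by mx_rec/__init__.py.
import types

target_module = types.SimpleNamespace(save=lambda path: "saved to " + path)


def patch_for_saver():
    original_save = target_module.save

    def patched_save(path):
        print("pre-save hook")           # extra behaviour injected by the patch
        return original_save(path)       # then delegate to the original

    target_module.save = patched_save    # rebind; all later callers see the wrapper


patch_for_saver()                        # applied once, at package import time
print(target_module.save("ckpt/model"))  # prints the hook line, then "saved to ckpt/model"
```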
diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py
index 03fa28b429f1176dcc6e4e5e8eb99dc969e14524..f8558cd98722c04fa6030cb739d553a2ab615a3c 100644
--- a/mx_rec/constants/constants.py
+++ b/mx_rec/constants/constants.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 from enum import Enum
 
 import numpy as np
@@ -22,16 +21,19 @@ ASCEND_GLOBAL_HASHTABLE_COLLECTION = "ASCEND_GLOBAL_HASHTABLE_COLLECTION"
 ASCEND_CUTTING_POINT_INITIALIZER = "ASCEND_CUTTING_POINT_INITIALIZER"
 ASCEND_SPARSE_LOOKUP_ENTRANCE = "ASCEND_SPARSE_LOOKUP_ENTRANCE"
 ASCEND_SPARSE_LOOKUP_ID_OFFSET = "ASCEND_SPARSE_LOOKUP_ID_OFFSET"
-ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS = "ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS"
 ASCEND_TIMESTAMP = "ASCEND_TIMESTAMP"
 ASCEND_SPARSE_LOOKUP_LOCAL_EMB = "ASCEND_SPARSE_LOOKUP_LOCAL_EMB"
 EMPTY_STR = ""
 
+# default embedding cache memory sizes for HBM, DDR and SSD
+DEFAULT_DEVICE_CACHE_MEMORY_SIZE = 2 * 1024 * 1024 * 1024
+DEFAULT_HOST_CACHE_MEMORY_SIZE = 40 * 1024 * 1024 * 1024
+
 # error message shown when getting the ConfigInitializer instance fails
 GET_CONFIG_INSTANCE_ERR_MSG = "Please init the environment for mx_rec at first."
 
-# anchor name used to find the dataset in the graph under automatic graph modification
-ANCHOR_DATASET_NAME = "PrefetchDataset"
+# Used for slicer finding the orphan lookup key.
+ORPHAN_LOOKUP_KEY_PREFIX = "orphan"
 
 # the name of the embedding table merged by third party
 ASCEND_TABLE_NAME_MUST_CONTAIN = None
@@ -44,6 +46,11 @@ DEFAULT_HD_CHANNEL_SIZE = 40
 MAX_HD_CHANNEL_SIZE = 8192
 MIN_HD_CHANNEL_SIZE = 2
 
+# CM_WORKER_SIZE: number of nodes in the cluster
+DEFAULT_CM_WORKER_SIZE = 0
+MAX_CM_WORKER_SIZE = 512
+MIN_CM_WORKER_SIZE = 0
+
 # number of key-process threads
 DEFAULT_KP_THREAD_NUM = 6
 MIN_KP_THREAD_NUM = 1
@@ -64,7 +71,7 @@ DEFAULT_EVICT_TIME_INTERVAL = 60 * 60 * 24
 TRAIN_CHANNEL_ID = 0
 EVAL_CHANNEL_ID = 1
 HASHTABLE_COLLECTION_NAME_LENGTH = 30
-MAX_VOCABULARY_SIZE = 10**10
+MAX_VOCABULARY_SIZE = 10**9
 MAX_DEVICE_VOCABULARY_SIZE = 10 ** 9
 
 # RANK INFO
@@ -117,7 +124,6 @@ class BaseEnum(Enum):
 class EnvOption(Enum):
     MXREC_LOG_LEVEL = "MXREC_LOG_LEVEL"
     RANK_TABLE_FILE = "RANK_TABLE_FILE"
-    ASCEND_VISIBLE_DEVICES = "ASCEND_VISIBLE_DEVICES"
     CM_CHIEF_DEVICE = "CM_CHIEF_DEVICE"
     CM_WORKER_SIZE = "CM_WORKER_SIZE"
     TF_DEVICE = "TF_DEVICE"
@@ -139,6 +145,12 @@ class EnvOption(Enum):
     OMPI_COMM_WORLD_RANK = "OMPI_COMM_WORLD_RANK"
 
 
+class CacheModeEnum(Enum):
+    HBM = "HBM"
+    DDR = "DDR"
+    SSD = "SSD"
+
+
 class DataName(Enum):
     KEY = "key"
     EMBEDDING = "embedding"
@@ -166,8 +178,9 @@ class ASCAnchorAttr(Enum):
     MOCK_LOOKUP_RESULT = "mock_lookup_result"
     RESTORE_VECTOR_SECOND = "restore_vector_second"
     UNIQUE_KEYS = "unique_keys"
-    GRADIENTS_STRATEGY = "gradients_strategy"
     IS_GRAD = "is_grad"
+    TABLE_NAME = "table_name"
+    CHANNEL_ID = "channel_id"
 
 
 class OptimizerType(Enum):
@@ -214,16 +227,3 @@ class TFDevice(Enum):
 class Flag(Enum):
     TRUE = "1"
     FALSE = "0"
-
-
-class AnchorDatasetOp(Enum):
-    MODEL_DATASET = "ModelDataset"
-    OPTIMIZE_DATASET = "OptimizeDataset"
-    PREFETCH_DATASET = "PrefetchDataset"
-
-
-class AnchorIteratorOp(Enum):
-    ITERATOR_GET_NEXT = "IteratorGetNext"
-    MAKE_ITERATOR = "MakeIterator"
-    ONE_SHOT_ITERATOR = "OneShotIterator"
-
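The two cache-size constants added above (2 GB of device memory per rank, 40 GB of host memory in total) are what `check_and_set_default_voc_size` in `mx_rec/core/embedding.py` later divides by the per-row byte width to derive default vocabulary sizes. A back-of-envelope sketch of that arithmetic; `dim` and `rank_size` are assumed example values, not taken from the diff:

```python
# Rough sizing sketch mirroring check_and_set_default_voc_size further down.
DEFAULT_DEVICE_CACHE_MEMORY_SIZE = 2 * 1024 * 1024 * 1024   # 2 GiB per rank
DEFAULT_HOST_CACHE_MEMORY_SIZE = 40 * 1024 * 1024 * 1024    # 40 GiB in total

dim = 64                     # embedding dimension (example value)
dim_bytes = dim * 4          # float32 -> 4 bytes per element
rank_size = 8                # example: one node with 8 NPUs

device_voc = DEFAULT_DEVICE_CACHE_MEMORY_SIZE // dim_bytes * rank_size
host_voc = DEFAULT_HOST_CACHE_MEMORY_SIZE // dim_bytes
print(device_voc)  # 67108864 rows across 8 ranks (8388608 per rank)
print(host_voc)    # 167772160 rows in host memory
```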
diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py
index 13ddad4accbc881f1fc714a95531b305675facff..00b9d282cf92c8a575495a06057aab906ba7fd0c 100644
--- a/mx_rec/core/asc/build_graph.py
+++ b/mx_rec/core/asc/build_graph.py
@@ -15,19 +15,30 @@
 # limitations under the License.
 # ==============================================================================
 
-from typing import Optional
+from dataclasses import dataclass, field
+from typing import Optional, List, Dict, Union, Tuple
 
 import tensorflow as tf
 
 import mxrec_pybind
+from mx_rec.constants.constants import ASCAnchorAttr
 from mx_rec.util.initialize import ConfigInitializer
 from mx_rec.util.tf_version_adapter import npu_ops
-from mx_rec.constants.constants import TRAIN_CHANNEL_ID
 from mx_rec.util.log import logger
+from mx_rec.core.asc.swap_args import SwapArgs, SwapDataType
+
+
+@dataclass
+class SwapInfo:
+    swap_in_len: int = 0
+    swap_in_pos: List[tf.Tensor] = field(default_factory=list)
+    swap_out_len: int = 0
+    swap_out_pos: List[tf.Tensor] = field(default_factory=list)
 
 
 def get_restore_vector(config):
-    logger.debug('Channel %s_restore_%s was built for getnext', config.get("table_name"), config.get("channel_id"))
+    logger.debug('Channel %s_restore_%s was built for getnext', config.get(ASCAnchorAttr.TABLE_NAME.value),
+                 config.get(ASCAnchorAttr.CHANNEL_ID.value))
     if config.get("is_hbm"):
         if not isinstance(config.get("emb_size"), int) or config.get("emb_size") < 1:
             raise TypeError("emb_size must be an int")
@@ -39,86 +50,60 @@ def get_restore_vector(config):
             raise TypeError("ext_emb_size must be an int")
         if config.get("ext_emb_size") < 1:
             raise ValueError("ext_emb_size is less than 1")
-        emb_size = config.get("ext_emb_size")
+        emb_size = config.get("emb_size")
 
     if ConfigInitializer.get_instance().use_static:
         restore_size = config.get("batch_size") * config.get("feat_cnt")
+        device_id = int(config.get("device_id"))
+        hot_size = int(mxrec_pybind.get_ub_hot_size(device_id) / emb_size)
     else:
         restore_size = None
+        hot_size = None
 
-    with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE):
-        device_id = int(config.get("device_id"))
-        hot_size = int(mxrec_pybind.get_ub_hot_size(device_id) / emb_size)
+    with tf.compat.v1.variable_scope(config.get(ASCAnchorAttr.TABLE_NAME.value), reuse=tf.compat.v1.AUTO_REUSE):
         restore_vector, hot_pos = npu_ops.gen_npu_ops.get_next(
             output_types=[tf.int32, tf.int32],
             output_shapes=[restore_size, [hot_size]],
-            channel_name=f'{config.get("table_name")}_restore_{config.get("channel_id")}')
+            channel_name=f'{config.get(ASCAnchorAttr.TABLE_NAME.value)}'
+                         f'_restore_{config.get(ASCAnchorAttr.CHANNEL_ID.value)}')
 
     return restore_vector, hot_pos
 
 
-def get_id_offsets(max_lookup_vec_size, config):
-    logger.debug('Channel %s_lookup_%s was built for getnext', config.get("table_name"), config.get("channel_id"))
+def get_id_offsets(max_lookup_vec_size: int, config: dict) -> Tuple[tf.Tensor, SwapInfo]:
+    logger.debug('Channel %s_lookup_%s was built for getnext', config.get(ASCAnchorAttr.TABLE_NAME.value),
+                 config.get(ASCAnchorAttr.CHANNEL_ID.value))
     # dynamic expansion currently supports HBM mode only, so there is no swap-in/swap-out by default
-    with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE):
+    swap_info = SwapInfo()
+
+    with tf.compat.v1.variable_scope(config.get(ASCAnchorAttr.TABLE_NAME.value), reuse=tf.compat.v1.AUTO_REUSE):
         if config.get("use_dynamic_expansion"):
             [id_offsets] = npu_ops.gen_npu_ops.get_next(
                 output_types=[tf.int64],
                 output_shapes=[[max_lookup_vec_size]],
-                channel_name=f'{config.get("table_name")}_lookup_{config.get("channel_id")}')
-            return id_offsets, [], 0
+                channel_name=f'{config.get(ASCAnchorAttr.TABLE_NAME.value)}'
+                             f'_lookup_{config.get(ASCAnchorAttr.CHANNEL_ID.value)}')
+            return id_offsets, swap_info
 
         [id_offsets] = npu_ops.gen_npu_ops.get_next(
output_types=[tf.int32], output_shapes=[[max_lookup_vec_size]], - channel_name=f'{config.get("table_name")}_lookup_{config.get("channel_id")}') + channel_name=f'{config.get(ASCAnchorAttr.TABLE_NAME.value)}' + f'_lookup_{config.get(ASCAnchorAttr.CHANNEL_ID.value)}') if config.get("is_hbm"): - return id_offsets, [], 0 - swap_pos, swap_len = npu_ops.gen_npu_ops.get_next( - output_types=[tf.int32, tf.int32], - output_shapes=[[max_lookup_vec_size], []], - channel_name=f'{config.get("table_name")}_swap_{config.get("channel_id")}') - return id_offsets, swap_pos, swap_len - - -def get_restore_vector_second(max_lookup_vec_size: int, config: dict) -> tf.Tensor: - """ - Get restore vector which is calculated after the second all2all - :param max_lookup_vec_size: the size of restore_vector_second - :param config: embedding config - :return: the restore vector calculated after the second all2all - """ - logger.debug('Channel %s_restore_second_%s was built for getnext', - config.get("table_name"), config.get("channel_id")) - with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE): - restore_vector_second = npu_ops.gen_npu_ops.get_next( - output_types=[tf.int32], - output_shapes=[[max_lookup_vec_size]], - channel_name=f'{config.get("table_name")}_restore_second_{config.get("channel_id")}')[0] - return restore_vector_second - - -def get_unique_keys(max_lookup_vec_size: int, config: dict) -> tf.Tensor: - """ - Get the global unique keys which is calculated after the second all2all - :param max_lookup_vec_size: the size of global unique keys - :param config: embedding config - :return: the global unique keys calculated after the second all2all - """ - logger.debug('Channel %s_uniquekeys_%s was built for getnext', config.get("table_name"), config.get("channel_id")) - with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE): - if config.get("use_dynamic_expansion"): - unique_keys = npu_ops.gen_npu_ops.get_next( - output_types=[tf.int64], - output_shapes=[[max_lookup_vec_size]], - channel_name=f'{config.get("table_name")}_uniquekeys_{config.get("channel_id")}')[0] - return unique_keys - - unique_keys = npu_ops.gen_npu_ops.get_next( - output_types=[tf.int32], - output_shapes=[[max_lookup_vec_size]], - channel_name=f'{config.get("table_name")}_uniquekeys_{config.get("channel_id")}')[0] - return unique_keys + return id_offsets, swap_info + ( + swap_info.swap_in_pos, + swap_info.swap_out_pos, + swap_info.swap_in_len, + swap_info.swap_out_len, + ) = npu_ops.gen_npu_ops.get_next( + output_types=[tf.int32, tf.int32, tf.int32, tf.int32], + output_shapes=[[max_lookup_vec_size], [max_lookup_vec_size], [], []], + channel_name=f'{config.get(ASCAnchorAttr.TABLE_NAME.value)}_swap_all', + ) + logger.debug('Channel %s_swap_all was built for getnext', config.get(ASCAnchorAttr.TABLE_NAME.value)) + return id_offsets, swap_info def get_all2all_args(use_static: bool, config: dict) -> Optional[list]: @@ -132,61 +117,20 @@ def get_all2all_args(use_static: bool, config: dict) -> Optional[list]: if use_static: return all2all_args - with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE): + with tf.compat.v1.variable_scope(config.get(ASCAnchorAttr.TABLE_NAME.value), reuse=tf.compat.v1.AUTO_REUSE): with tf.compat.v1.variable_scope("all2all"): - logger.debug('Channel %s_a2a_%s was built for getnext', config.get("table_name"), config.get("channel_id")) + logger.debug('Channel %s_a2a_%s was built for getnext', 
config.get(ASCAnchorAttr.TABLE_NAME.value), + config.get(ASCAnchorAttr.CHANNEL_ID.value)) all2all_args = npu_ops.gen_npu_ops.get_next( output_types=[tf.int64], output_shapes=[[config.get("rank_size"), config.get("rank_size")]], - channel_name=f'{config.get("table_name")}_all2all_{config.get("channel_id")}', + channel_name=f'{config.get(ASCAnchorAttr.TABLE_NAME.value)}' + f'_all2all_{config.get(ASCAnchorAttr.CHANNEL_ID.value)}', name="a2a_get_next")[0] * config.get("emb_size") return all2all_args -def get_swap_info(config: dict, swap_len: int, swap_pos: list, table: tf.Variable) -> list: - """ - Get swap info if threshold is configured. - :param config: training job config - :param swap_len: swap length - :param swap_pos: swap position - :param table: the instance to do swap - :return: swap info - """ - use_static = ConfigInitializer.get_instance().use_static - max_lookup_vec_size = None - if use_static: - max_lookup_vec_size = config.get("send_count") * config.get("rank_size") - - if config.get("is_hbm"): - swap_in = [tf.no_op()] - else: - with tf.compat.v1.variable_scope("h2d_emb"): - logger.debug('Channel %s_h2d_%s was built for getnext', config.get("table_name"), config.get("channel_id")) - h2d_emb = npu_ops.gen_npu_ops.get_next( - output_types=[tf.float32], - output_shapes=[[max_lookup_vec_size, config.get("ext_emb_size")]], - channel_name=f'{config.get("table_name")}_h2d_{config.get("channel_id")}')[0] - logger.debug("h2d_emb shape: %s", h2d_emb) - if not isinstance(table, list): - raise RuntimeError("When enable emb_transfer, optimizer should have slots") - if use_static: - swap_pos = swap_pos[0:swap_len] - h2d_emb = h2d_emb[0:swap_len, :] - swap_outs = [tf.gather(one_table, swap_pos) for one_table in table] - swap_out = tf.concat(swap_outs, axis=1) - logger.debug('Channel %s_d2h_%s was built for op outfeed.', config.get("table_name"), config.get("channel_id")) - swap_out_op = npu_ops.outfeed_enqueue_op( - channel_name=f'{config.get("table_name")}_d2h_{config.get("channel_id")}', inputs=[swap_out]) - with tf.control_dependencies([swap_out_op]): - nd_swap_pos = tf.expand_dims(swap_pos, 1) - table_num = len(table) - h2d_emb_split = tf.split(h2d_emb, table_num, axis=1) - swap_in = [tf.compat.v1.scatter_nd_update(table[i], nd_swap_pos, h2d_emb_split[i]) - for i in range(len(table))] - return swap_in - - def get_preprocessed_tensor_for_asc(table, config): use_static = ConfigInitializer.get_instance().use_static max_lookup_vec_size = None @@ -197,27 +141,22 @@ def get_preprocessed_tensor_for_asc(table, config): restore_vector, hot_pos = get_restore_vector(config) with tf.compat.v1.variable_scope("id_offsets"): - id_offsets, swap_pos, swap_len = get_id_offsets(max_lookup_vec_size, config) + id_offsets, swap_info = get_id_offsets(max_lookup_vec_size, config) - all2all_args = get_all2all_args(use_static, config) + if not config.get("is_hbm"): + # 一表多查时,会多次进入get_preprocessed_tensor_for_asc,最后一次大查询替换map的key-value即可 + swap_args = SwapArgs() + + swap_args.set_data(SwapDataType.CONFIG.value, var_name=config.get(ASCAnchorAttr.TABLE_NAME.value), + var_channel=config.get(ASCAnchorAttr.CHANNEL_ID.value), config=config, swap_info=swap_info) - swap_in = get_swap_info(config, swap_len, swap_pos, table) + all2all_args = get_all2all_args(use_static, config) result = { 'restore_vector': restore_vector, 'hot_pos': hot_pos, 'id_offsets': id_offsets, - 'swap_in': swap_in, 'all2all_args': all2all_args, } - if config.get("channel_id") != TRAIN_CHANNEL_ID: - return result - - with 
tf.compat.v1.variable_scope("restore_vector_second"): - restore_vector_second = get_restore_vector_second(max_lookup_vec_size, config) - - with tf.compat.v1.variable_scope("unique_keys"): - unique_keys = get_unique_keys(max_lookup_vec_size, config) - result.update({'restore_vector_second': restore_vector_second, 'unique_keys': unique_keys}) return result diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 771f359f93e51d6db32b03e48502a2fe029d3575..aaa9701701aa0fdca8a54e58ecb299365c8f33d9 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -281,7 +281,7 @@ def do_insert(args, insert_tensors, splits, table_names, input_dict): def export_read_emb_key_v2_op(args, pipeline_op): origin_batch = list(args) if len(origin_batch) < 1: - raise ValueError("The length of args is less than 1.") + raise ValueError("the length of args is less than 1.") if isinstance(origin_batch[0], dict): output_batch = origin_batch[0] valid_key = get_valid_op_key(output_batch) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 2829ab9853c909427c64bf00e8c7ca64bb8da7e8..3a24b3d707ec4f527ad296e4bb01df291deb7599 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -18,7 +18,7 @@ import tensorflow as tf from mxrec_pybind import InitializeInfo, ConstantInitializerInfo, NormalInitializerInfo, EmbInfo, EmbInfoParams, \ - ThresholdValue, HybridMgmt, RankInfo, USE_STATIC, USE_HOT, USE_DYNAMIC_EXPANSION + ThresholdValue, HybridMgmt, RankInfo, USE_STATIC, USE_DYNAMIC_EXPANSION, USE_SUM_SAME_ID_GRADIENTS from mx_rec.util.communication.hccl_ops import get_rank_id, get_device_id, get_rank_size from mx_rec.util.initialize import ConfigInitializer @@ -37,16 +37,16 @@ def generate_table_info_list(): raise ValueError(f"The DDR mode of all tables must be used or not used at the same time. 
However, is_hbm " f"of each table `{table_instance_dict.keys()}` is `{is_hbm_list}`.")
 
-    optimizer = ConfigInitializer.get_instance().optimizer_config.optimizer_instance
+    # optimizer_instance is created via create_hash_optimizer
+    optimizer_instance = ConfigInitializer.get_instance().optimizer_config.optimizer_instance
 
     # generate table info
     dangling_table = check_dangling_table()
     for _, table_instance in ConfigInitializer.get_instance().sparse_embed_config.table_instance_dict.items():
-        # When dynamic expansion mode, ext_emb_size is set by optimizer
-        if ConfigInitializer.get_instance().use_dynamic_expansion and optimizer:
-            table_instance.ext_emb_size = table_instance.emb_size * (1 + optimizer.slot_num)
-            logger.debug("ext_emb_size is reset to be %s for EmbInfo", table_instance.ext_emb_size)
-
+        # dynamic expansion scenario in FS mode
+        if ConfigInitializer.get_instance().use_dynamic_expansion and optimizer_instance:
+            table_instance.ext_emb_size = table_instance.emb_size * (1 + optimizer_instance.slot_num)
+            logger.info("ext_emb_size is reset to be %s in generate_table_info_list.", table_instance.ext_emb_size)
         skip = should_skip(table_instance.table_name)
         if table_instance.table_name in dangling_table or skip:
             logger.info("skip table %s: %s which does not need to be provided to the EmbInfo.",
@@ -158,9 +158,8 @@ def matched_opt_slot_initializers(table_instance):
             slot_initializers.append(slot_initializer)
         start_index += table_instance.emb_size
 
-    logger.debug("matched_opt_slot_initializers, ext emb size:%s, optimizer_instance_list size:%s, "
-                 "slot_initializers size:%s", table_instance.ext_emb_size, len(table_instance.optimizer_instance_list),
-                 len(slot_initializers))
+    logger.debug("matched_opt_slot_initializers, ext emb size:%s, slot_initializers size:%s",
+                 table_instance.ext_emb_size, len(slot_initializers))
     return slot_initializers
 
 
@@ -195,18 +194,21 @@ def initialize_emb_cache(table_info_list, threshold_list):
     train_steps = ConfigInitializer.get_instance().train_steps
     eval_steps = ConfigInitializer.get_instance().eval_steps
     save_steps = ConfigInitializer.get_instance().save_steps
+    max_train_steps = ConfigInitializer.get_instance().max_steps
     if_load = ConfigInitializer.get_instance().if_load
 
     option = 0
     if ConfigInitializer.get_instance().use_static:
         option = option | USE_STATIC
-        # use hot always True
-        option = option | USE_STATIC << 1
     if ConfigInitializer.get_instance().use_dynamic_expansion:
         option = option | USE_DYNAMIC_EXPANSION
 
-    # [train_steps, eval_steps, save_steps] pass step information to HybridMgmt for data process loop
-    rank_info = RankInfo(rank_id, device_id, rank_size, option, [train_steps, eval_steps, save_steps])
+    optimizer = ConfigInitializer.get_instance().optimizer_config.optimizer_instance
+    if optimizer and optimizer.derivative == 2:
+        option = option | USE_SUM_SAME_ID_GRADIENTS
+
+    # pass step information to HybridMgmt for data process loop
+    rank_info = RankInfo(rank_id, device_id, rank_size, option, [train_steps, eval_steps, save_steps, max_train_steps])
 
     emb_cache = HybridMgmt()
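`initialize_emb_cache` packs several boolean features into a single `option` bitmask before handing it to `RankInfo`; the new `USE_SUM_SAME_ID_GRADIENTS` bit is only set for a second-order optimizer (`derivative == 2`). A sketch of how such a mask composes; the flag values here are illustrative, the real constants come from `mxrec_pybind`:

```python
# Illustrative bit flags (the real USE_* values are defined by mxrec_pybind).
USE_STATIC = 1 << 0
USE_DYNAMIC_EXPANSION = 1 << 2
USE_SUM_SAME_ID_GRADIENTS = 1 << 3

use_static, use_dynamic_expansion, derivative = True, False, 2

option = 0
if use_static:
    option |= USE_STATIC
if use_dynamic_expansion:
    option |= USE_DYNAMIC_EXPANSION
if derivative == 2:                  # second-order optimizer
    option |= USE_SUM_SAME_ID_GRADIENTS
print(bin(option))                   # 0b1001
```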
diff --git a/mx_rec/core/asc/merge_table.py b/mx_rec/core/asc/merge_table.py
index 776a72c43da351a98fc25661406740646e4a7170..fb993032973a6911099a396aecfd8649807b3143 100644
--- a/mx_rec/core/asc/merge_table.py
+++ b/mx_rec/core/asc/merge_table.py
@@ -196,7 +196,9 @@ def check_dangling_table():
     config_instance = ConfigInitializer.get_instance()
     dangling_table = config_instance.sparse_embed_config.dangling_table
     if not dangling_table:
-        dangling_table = find_dangling_table([table_instance.table_name
-                                              for _, table_instance in
-                                              config_instance.sparse_embed_config.table_instance_dict.items()])
+        table_names = []
+        for _, table_instance in config_instance.sparse_embed_config.table_instance_dict.items():
+            table_names.append(table_instance.table_name)
+        dangling_table = find_dangling_table(table_names)
+
     return dangling_table
diff --git a/mx_rec/core/asc/swap_args.py b/mx_rec/core/asc/swap_args.py
new file mode 100644
index 0000000000000000000000000000000000000000..3157e1e0e0d93304ddb5029e8e241e728df31f61
--- /dev/null
+++ b/mx_rec/core/asc/swap_args.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import functools
+from collections import defaultdict
+from enum import Enum
+
+
+class SwapDataType(Enum):
+    CONFIG = "config"
+    CONTROL = "control"
+    CONTROL_OPS = "control_ops"
+
+
+def singleton(cls):
+    _instance = {}
+
+    @functools.wraps(cls)
+    def inner():
+        if cls not in _instance:
+            _instance[cls] = cls()
+        return _instance[cls]
+
+    return inner
+
+
+@singleton
+class SwapArgs:
+    def __init__(self):
+        self.swap_config_dict = defaultdict(dict)
+        self.swap_control_dict = defaultdict(dict)
+        self.slot_control_dict = defaultdict(dict)
+
+    def set_data(self, data_type: str, **kwargs):
+        if "var_name" not in kwargs:
+            raise ValueError("missing required key: var_name")
+        if "var_channel" not in kwargs:
+            raise ValueError("missing required key: var_channel")
+        var_name = kwargs.pop("var_name")
+        var_channel = kwargs.pop("var_channel")
+
+        if data_type == SwapDataType.CONFIG.value:
+            self.swap_config_dict[var_name][var_channel] = kwargs
+        elif data_type == SwapDataType.CONTROL.value:
+            self.swap_control_dict[var_name][var_channel] = kwargs
+        else:
+            raise ValueError(f"unknown data type in swap args: {data_type}")
+
+    def set_slot_control(self, **kwargs):
+        if "var_name" not in kwargs:
+            raise ValueError("missing required key: var_name")
+        var_name = kwargs.pop("var_name")
+        self.slot_control_dict[var_name] = kwargs
diff --git a/mx_rec/core/emb/base_sparse_embedding.py b/mx_rec/core/emb/base_sparse_embedding.py
index 07dc70f752c194e243d9c278869b7d4e2072cd6c..1a59bd24cb9bbc5d8248a0d4882ddae8586c9ead 100644
--- a/mx_rec/core/emb/base_sparse_embedding.py
+++ b/mx_rec/core/emb/base_sparse_embedding.py
@@ -10,7 +10,9 @@ import tensorflow as tf
 from tensorflow.python.ops import array_ops
 
 from mx_rec.constants.constants import All2allGradientsOp, ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCAnchorAttr
+from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc
 from mx_rec.core.asc.feature_spec import set_temporary_feature_spec_attribute, get_feature_spec, FeatureSpec
+from mx_rec.core.asc.swap_args import SwapArgs, SwapDataType
 from mx_rec.util.communication.hccl_ops import get_rank_size, get_rank_id, get_device_id
 from mx_rec.util.tf_version_adapter import hccl_ops
 from mx_rec.util.initialize
import ConfigInitializer @@ -81,14 +83,6 @@ class BaseSparseEmbedding(metaclass=abc.ABCMeta): ConfigInitializer.get_instance().train_params_config.ascend_global_hashtable_collection, self._variable) self._set_ext_emb_size() - @property - def optimizer_instance_list(self): - return [] - - @property - def optimizer(self): - return dict() - @property def embedding_size(self): return self._embedding_size @@ -117,6 +111,10 @@ class BaseSparseEmbedding(metaclass=abc.ABCMeta): def send_count(self): return self._send_count + @property + def rank_size(self): + return self._rank_size + @property def slice_device_vocabulary_size(self): return self._slice_device_vocabulary_size @@ -201,35 +199,11 @@ class BaseSparseEmbedding(metaclass=abc.ABCMeta): """ pass - @abc.abstractmethod - def set_optimizer(self, key: str, state_dict: dict): - """ - 设置optimizer state. - - Args: - key: 优化器名字 - state_dict: optimizer state - - Returns: None - """ - pass @abc.abstractmethod def _set_slice_vocab_size(self): pass - @abc.abstractmethod - def _set_ext_emb_size(self): - pass - - @abc.abstractmethod - def _build_optimizer_states(self): - pass - - @abc.abstractmethod - def _get_preprocessed_tensor(self, feature_spec: FeatureSpec, is_training: bool, send_count: Optional[int]) -> dict: - pass - @abc.abstractmethod def _get_update_grad(self, local_grad: tf.Tensor, result: dict, table: Union[tf.compat.v1.Variable, tf.Tensor]) -> Union[tf.IndexedSlices, tf.Tensor]: @@ -322,6 +296,8 @@ class BaseSparseEmbedding(metaclass=abc.ABCMeta): # set modify graph self._modify_graph = kwargs.get("modify_graph", True) + if not self._modify_graph and not self._is_hbm: + raise RuntimeError("when the 'ddr or ssd' mode are used, the 'modify graph' is required") # return the stub tensor of the lookup result if not self._use_static: @@ -354,7 +330,9 @@ class BaseSparseEmbedding(metaclass=abc.ABCMeta): return lookup_result if not self._use_static and not self._modify_graph and kwargs.get("batch") is None: - raise RuntimeError("When the 'feature spec' mode and 'dynamic shape' are used, the 'batch' is required.") + raise RuntimeError("when the 'feature spec' mode and 'dynamic shape' are used, the 'batch' is required") + if not self._modify_graph and not self._is_hbm: + raise RuntimeError("when the 'ddr or ssd' mode are used, the 'modify graph' is required") table_name = feature_spec.table_name same_table_feature_spec = \ ConfigInitializer.get_instance().feature_spec_config.table_name_to_feature_spec[table_name][is_training] @@ -401,6 +379,20 @@ class BaseSparseEmbedding(metaclass=abc.ABCMeta): return tf.stop_gradient(self._lookup_result.get(spec_name).get(is_training), name="stop_grad_lookup_res") return self._lookup_result.get(spec_name).get(is_training) + def _set_ext_emb_size(self): + # 初始设置_ext_emb_size等于_emb_size,改图阶段会根据优化器的不同而exchange该值 + self._ext_emb_size = self._emb_size * self._ext_coefficient + logger.debug("Init table, ext_emb_size is set to be %s.", self._ext_emb_size) + + def _get_preprocessed_tensor(self, feature_spec: FeatureSpec, channel_id: int, send_count: Optional[int]) -> dict: + config = dict(batch_size=feature_spec.batch_size, feat_cnt=feature_spec.feat_cnt, send_count=send_count, + rank_size=self._rank_size, channel_id=channel_id, table_name=self._table_name, + is_hbm=self._is_hbm, ext_emb_size=self._ext_emb_size, emb_size=self._emb_size, + use_dynamic_expansion=ConfigInitializer.get_instance().use_dynamic_expansion, + device_id=self._device_id) + + return get_preprocessed_tensor_for_asc(self._variable, config) + def 
_lookup_forward(self, feature_spec: FeatureSpec, send_count: Optional[int], **kwargs) -> tf.Tensor: is_training = kwargs.get("is_train") hashtable_params = dict(slice_device_vocabulary_size=self._slice_device_vocabulary_size, @@ -409,7 +401,8 @@ class BaseSparseEmbedding(metaclass=abc.ABCMeta): check_emb_lookup_params(hashtable_params, feature_spec, send_count, is_training) if ConfigInitializer.get_instance().use_static: self._send_count = send_count - result = self._get_preprocessed_tensor(feature_spec, is_training, send_count) + channel_id = ConfigInitializer.get_instance().train_params_config.get_training_mode_channel_id(is_training) + result = self._get_preprocessed_tensor(feature_spec, channel_id, send_count) @tf.custom_gradient def sparse_forward(table): @@ -469,7 +462,11 @@ class BaseSparseEmbedding(metaclass=abc.ABCMeta): return array_ops.reshape(embeddings, dest_shape), grad - with tf.control_dependencies(result.get("swap_in")): + ddr_control_ops = tf.no_op(name="place_holder_swap_op") + swap_args = SwapArgs() + swap_args.set_data(SwapDataType.CONTROL.value, var_name=self._table_name, var_channel=channel_id, + control_ops=ddr_control_ops) + with tf.control_dependencies([ddr_control_ops]): return self._get_sparse_forward_result(sparse_forward, self._variable, result, is_training) def __initialize_variables(self): @@ -481,7 +478,6 @@ class BaseSparseEmbedding(metaclass=abc.ABCMeta): ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(self._variable.name) self.__record() - self._build_optimizer_states() def __record(self, eval_flag=False): ConfigInitializer.get_instance().sparse_embed_config.insert_table_instance( diff --git a/mx_rec/core/emb/dynamic_sparse_embedding.py b/mx_rec/core/emb/dynamic_sparse_embedding.py index 194b2795129af05a42d0217f2d566a60d7be3ad6..8dfe504c7cba4a648f951ffa5ac7a50d3ea05851 100644 --- a/mx_rec/core/emb/dynamic_sparse_embedding.py +++ b/mx_rec/core/emb/dynamic_sparse_embedding.py @@ -6,12 +6,10 @@ import abc from typing import Optional, Union, Callable import tensorflow as tf -from tensorflow.python.ops import array_ops from mx_rec.constants.constants import ASCEND_TABLE_NAME_MUST_CONTAIN, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, \ - ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS + ASCEND_SPARSE_LOOKUP_ID_OFFSET, ASCAnchorAttr from mx_rec.core.asc.feature_spec import FeatureSpec -from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc from mx_rec.core.emb.base_sparse_embedding import BaseSparseEmbedding from mx_rec.util.initialize import ConfigInitializer from mx_rec.util.log import logger @@ -29,31 +27,13 @@ class DynamicSparseEmbedding(BaseSparseEmbedding): def capacity(self) -> int: return ConfigInitializer.get_instance().hybrid_manager_config.asc_manager.get_table_capacity(self._table_name) - @abc.abstractmethod - def set_optimizer(self, key: str, state_dict: dict): - pass - - @abc.abstractmethod - def _build_optimizer_states(self): - pass - - @abc.abstractmethod - def _set_ext_emb_size(self): - pass - @abc.abstractmethod def _set_slice_vocab_size(self): pass - @abc.abstractmethod - def _get_preprocessed_tensor(self, feature_spec: FeatureSpec, is_training: bool, send_count: Optional[int]) -> dict: - pass - def _get_update_grad(self, local_grad: tf.Tensor, result: dict, table: Union[tf.compat.v1.Variable, tf.Tensor]) -> Union[tf.IndexedSlices, tf.Tensor]: - return tf.compat.v1.unsorted_segment_sum(local_grad, - result.get("restore_vector_second"), - array_ops.shape(result.get("unique_keys"))[0]) + return local_grad def 
_get_local_embeddings(self, table: Union[tf.compat.v1.Variable, tf.Tensor], result: dict, feature_spec: FeatureSpec, **kwargs) -> tf.Tensor: @@ -62,7 +42,7 @@ class DynamicSparseEmbedding(BaseSparseEmbedding): def _get_sparse_forward_result(self, sparse_forward_fn: Callable, table: Union[tf.compat.v1.Variable, tf.Tensor], result: dict, is_training: bool) -> tf.Tensor: local_embeddings = import_host_pipeline_ops().embedding_lookup_by_address( - result.get("id_offsets"), embedding_dim=self._emb_size, embedding_type=1) + result.get(str(ASCAnchorAttr.ID_OFFSETS.value)), embedding_dim=self._emb_size, embedding_type=1) add_collection_condition = is_training and ( ASCEND_TABLE_NAME_MUST_CONTAIN is None or ASCEND_TABLE_NAME_MUST_CONTAIN in self._table_name) @@ -70,9 +50,11 @@ class DynamicSparseEmbedding(BaseSparseEmbedding): self._table_name, ASCEND_TABLE_NAME_MUST_CONTAIN) if not add_collection_condition: return sparse_forward_fn(local_embeddings) - + # 创建扩容查询tensor和table_instance的映射关系,以便优化器中使用 + ConfigInitializer.get_instance().sparse_embed_config.insert_table_instance_to_tensor_dict( + result.get(str(ASCAnchorAttr.ID_OFFSETS.value)), self) tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings) - tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS, result.get("unique_keys")) + tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, result.get(str(ASCAnchorAttr.ID_OFFSETS.value))) return sparse_forward_fn(local_embeddings) @@ -84,25 +66,7 @@ class HBMDynamicSparseEmbedding(DynamicSparseEmbedding): def __init__(self, config: dict): super(DynamicSparseEmbedding, self).__init__(config) - def set_optimizer(self, key: str, state_dict: dict): - pass - - def _build_optimizer_states(self): - pass - - def _set_ext_emb_size(self): - self._ext_emb_size = self._emb_size * self._ext_coefficient - logger.debug("init table, ext_emb_size is set to be %s.", self._ext_emb_size) - def _set_slice_vocab_size(self): # 动态扩容模式下,保留device侧variable,大小设置为1 self._slice_device_vocabulary_size = 1 - def _get_preprocessed_tensor(self, feature_spec: FeatureSpec, is_training: bool, send_count: Optional[int]) -> dict: - channel_id = ConfigInitializer.get_instance().train_params_config.get_training_mode_channel_id(is_training) - config = dict(batch_size=feature_spec.batch_size, feat_cnt=feature_spec.feat_cnt, send_count=send_count, - rank_size=self._rank_size, channel_id=channel_id, table_name=self._table_name, - is_hbm=self._is_hbm, ext_emb_size=self._ext_emb_size, - emb_size=self._emb_size, device_id=self._device_id, use_dynamic_expansion=True) - - return get_preprocessed_tensor_for_asc(self._variable, config) diff --git a/mx_rec/core/emb/sparse_embedding.py b/mx_rec/core/emb/sparse_embedding.py index d8ce63b1435eec826c3cf0ae3d0fbf1b187720fc..39af9d60a2ae2bf9722d5dc07f839cee44645ddd 100644 --- a/mx_rec/core/emb/sparse_embedding.py +++ b/mx_rec/core/emb/sparse_embedding.py @@ -11,10 +11,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from mx_rec.core.asc.feature_spec import FeatureSpec -from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc from mx_rec.core.emb.base_sparse_embedding import BaseSparseEmbedding -from mx_rec.optimizers.emb_optimizer import EmbOptimizer -from mx_rec.util.initialize import ConfigInitializer from mx_rec.util.log import logger @@ -30,22 +27,6 @@ class SparseEmbedding(BaseSparseEmbedding): def capacity(self) -> int: pass - @abc.abstractmethod - def set_optimizer(self, key: str, state_dict: dict): - 
pass - - @abc.abstractmethod - def _set_ext_emb_size(self): - pass - - @abc.abstractmethod - def _build_optimizer_states(self): - pass - - @abc.abstractmethod - def _get_preprocessed_tensor(self, feature_spec: FeatureSpec, is_training: bool, send_count: Optional[int]) -> dict: - pass - def _set_slice_vocab_size(self): self._slice_device_vocabulary_size = math.ceil(self._device_vocabulary_size / self._rank_size) self._slice_host_vocabulary_size = math.ceil(self._host_vocabulary_size / self._rank_size) @@ -53,11 +34,8 @@ class SparseEmbedding(BaseSparseEmbedding): def _get_update_grad(self, local_grad: tf.Tensor, result: dict, table: Union[tf.compat.v1.Variable, tf.Tensor]) -> Union[tf.IndexedSlices, tf.Tensor]: - unique_local_grad = tf.compat.v1.unsorted_segment_sum(local_grad, - result.get("restore_vector_second"), - array_ops.shape(result.get("unique_keys"))[0]) - return ops.IndexedSlices(values=unique_local_grad, - indices=result.get("unique_keys"), + return ops.IndexedSlices(values=local_grad, + indices=result.get("id_offsets"), dense_shape=tf.shape(table)) def _get_local_embeddings(self, table: Union[tf.compat.v1.Variable, tf.Tensor], result: dict, @@ -87,25 +65,6 @@ class HBMSparseEmbedding(SparseEmbedding): def capacity(self) -> int: return self._device_vocabulary_size - def set_optimizer(self, key: str, state_dict: dict): - pass - - def _build_optimizer_states(self): - pass - - def _set_ext_emb_size(self): - self._ext_emb_size = self._emb_size * self._ext_coefficient - logger.debug("Init table, ext_emb_size is set to be %s.", self._ext_emb_size) - - def _get_preprocessed_tensor(self, feature_spec: FeatureSpec, is_training: bool, send_count: Optional[int]) -> dict: - channel_id = ConfigInitializer.get_instance().train_params_config.get_training_mode_channel_id(is_training) - config = dict(batch_size=feature_spec.batch_size, feat_cnt=feature_spec.feat_cnt, send_count=send_count, - rank_size=self._rank_size, channel_id=channel_id, table_name=self._table_name, - is_hbm=self._is_hbm, ext_emb_size=self._ext_emb_size, - emb_size=self._emb_size, device_id=self._device_id) - - return get_preprocessed_tensor_for_asc(self._variable, config) - class ExternalStorageSparseEmbedding(SparseEmbedding): """ @@ -113,52 +72,14 @@ class ExternalStorageSparseEmbedding(SparseEmbedding): """ def __init__(self, config: dict): - self.emb_optimizer = EmbOptimizer(config.get("optimizer_list")) - self.emb_optimizer.check_optimizer_instance_list() - super(ExternalStorageSparseEmbedding, self).__init__(config) - @property - def optimizer(self): - return self.emb_optimizer.optimizer - - @property - def optimizer_instance_list(self): - return self.emb_optimizer.optimizer_instance_list - def capacity(self) -> int: # DDR if not self._ssd_vocabulary_size: - return self._device_vocabulary_size + self._host_vocabulary_size + return self._host_vocabulary_size # SSD - return self._device_vocabulary_size + self._host_vocabulary_size + self._ssd_vocabulary_size - - def set_optimizer(self, key: str, state_dict: dict): - self.emb_optimizer.set_optimizer(key, state_dict, self._table_name) - - def _set_ext_emb_size(self): - self._ext_coefficient += len(self.emb_optimizer.optimizer_slot_info_list) - self._ext_emb_size = self._emb_size * self._ext_coefficient - logger.debug("Init table, ext_emb_size is set to be %s.", self._ext_emb_size) - - def _build_optimizer_states(self): - for sparse_optimizer_instance in self.emb_optimizer.optimizer_instance_list: - slot_info_list = sparse_optimizer_instance.initialize_slots(self._variable, 
self) - self.emb_optimizer.optimizer_slot_info_list.extend(slot_info_list) - - for slot_info in self.emb_optimizer.optimizer_slot_info_list: - self.emb_optimizer.set_optimizer_slot(slot_info) - - def _get_preprocessed_tensor(self, feature_spec: FeatureSpec, is_training: bool, send_count: Optional[int]) -> dict: - channel_id = ConfigInitializer.get_instance().train_params_config.get_training_mode_channel_id(is_training) - config = dict(batch_size=feature_spec.batch_size, feat_cnt=feature_spec.feat_cnt, send_count=send_count, - rank_size=self._rank_size, channel_id=channel_id, table_name=self._table_name, - is_hbm=self._is_hbm, ext_emb_size=self._ext_emb_size, - emb_size=self._emb_size, device_id=self._device_id) - - variable_list = [self._variable] + \ - [slot_info.get("slot") for slot_info in self.emb_optimizer.optimizer_slot_info_list] - return get_preprocessed_tensor_for_asc(variable_list, config) + return self._host_vocabulary_size + self._ssd_vocabulary_size def _set_specific_value_for_non_valid_key(id_offsets: Optional[tf.Tensor], diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index b38c486b69816815afa71771bb214c926b55aca8..23eb86aab982bdf4aabecdc4cf308a5a7e1e5fc0 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -16,18 +16,22 @@ # ============================================================================== import os -from typing import Optional, Union +from typing import Optional, Union, List import tensorflow as tf +from tensorflow import Tensor from tensorflow.python.ops.init_ops import Initializer as InitializerV1 from tensorflow.python.ops.init_ops_v2 import Initializer as InitializerV2 +from mx_rec.constants import constants from mx_rec.core.asc.feature_spec import FeatureSpec from mx_rec.core.emb.base_sparse_embedding import BaseSparseEmbedding from mx_rec.core.emb.emb_factory import HBMDynamicSparseEmbeddingFactory, HBMSparseEmbeddingFactory, \ ExternalStorageSparseEmbeddingFactory -from mx_rec.graph.utils import tag_orphan_ids -from mx_rec.constants.constants import MAX_INT32, All2allGradientsOp, MAX_VOCABULARY_SIZE, MAX_DEVICE_VOCABULARY_SIZE +from mx_rec.constants.constants import (MAX_INT32, All2allGradientsOp, MAX_VOCABULARY_SIZE, MAX_DEVICE_VOCABULARY_SIZE, + CacheModeEnum, DEFAULT_DEVICE_CACHE_MEMORY_SIZE, DEFAULT_HOST_CACHE_MEMORY_SIZE) +from mx_rec.graph.constants import AnchorIteratorOp +from mx_rec.util.communication.hccl_ops import get_rank_size from mx_rec.util.initialize import ConfigInitializer from mx_rec.validator.validator import ClassValidator, StringValidator, SSDFeatureValidator, \ para_checker_decorator, IntValidator, NumValidator, OptionValidator, OptionalIntValidator, \ @@ -43,27 +47,25 @@ from mx_rec.util.log import logger ("dim", NumValidator, {"min_value": 1, "max_value": 8192}, ["check_value"]), ("name", StringValidator, {"min_len": 1, "max_len": 100}, ["check_string_length", "check_whitelist"]), ("emb_initializer", ClassValidator, {"classes": (InitializerV1, InitializerV2)}), - ("optimizer_list", ClassValidator, {"classes": (list, type(None))}), (["ssd_vocabulary_size", "ssd_data_path", "host_vocabulary_size"], SSDFeatureValidator), ("device_vocabulary_size", IntValidator, {"min_value": 1, "max_value": MAX_DEVICE_VOCABULARY_SIZE}, ["check_value"]), ("host_vocabulary_size", IntValidator, {"min_value": 0, "max_value": MAX_VOCABULARY_SIZE}, ["check_value"]), ("ssd_vocabulary_size", IntValidator, {"min_value": 0, "max_value": MAX_VOCABULARY_SIZE}, ["check_value"]), ("ssd_data_path", ClassValidator, {"classes": 
(list, tuple)}), - ("is_save", ClassValidator, {"classes": (bool, )}), + ("is_save", ClassValidator, {"classes": (bool,)}), ("init_param", FloatValidator, {"min_value": -10, "max_value": 10}, ["check_value"]), ("all2all_gradients_op", OptionValidator, {"options": [i.value for i in list(All2allGradientsOp)]}), ("value_dtype", OptionValidator, {"options": [tf.float32]}), ("shard_num", IntValidator, {"min_value": 1, "max_value": 8192}, ["check_value"]), - ("fusion_optimizer_var", ClassValidator, {"classes": (bool, )}), + ("fusion_optimizer_var", ClassValidator, {"classes": (bool,)}), ("hashtable_threshold", IntValidator, {"min_value": 0, "max_value": MAX_INT32}, ["check_value"]) ]) def create_table(key_dtype, dim, name, emb_initializer, - optimizer_list: Optional[list] = None, device_vocabulary_size=1, host_vocabulary_size=0, ssd_vocabulary_size=0, - ssd_data_path=(os.getcwd(), ), + ssd_data_path=(os.getcwd(),), is_save=True, init_param=1., all2all_gradients_op=All2allGradientsOp.SUM_GRADIENTS.value, @@ -77,7 +79,6 @@ def create_table(key_dtype, dim, name, emb_initializer, dim: embedding vector size name: hash table name emb_initializer: the initializer for embedding values - optimizer_list: specify the optimizers to use for current hash table device_vocabulary_size: embedding vector numbers on device host_vocabulary_size: embedding vector numbers on ddr ssd_vocabulary_size: embedding vector numbers on ssd @@ -92,25 +93,28 @@ def create_table(key_dtype, dim, name, emb_initializer, """ name = fix_invalid_table_name(name) + dim_bytes = dim.as_list()[0] * 4 if isinstance(dim, tf.TensorShape) else dim * 4 # float32 4 bytes + voc_size_list = [device_vocabulary_size, host_vocabulary_size, ssd_vocabulary_size] + check_and_set_default_voc_size(voc_size_list, dim_bytes) + config = dict(key_dtype=key_dtype, embedding_size=dim, table_name=name, emb_initializer=emb_initializer, - device_vocabulary_size=device_vocabulary_size, host_vocabulary_size=host_vocabulary_size, - ssd_vocabulary_size=ssd_vocabulary_size, ssd_data_path=ssd_data_path, - optimizer_list=optimizer_list, init_param=init_param, is_save=is_save, - all2all_gradients_op=all2all_gradients_op) + device_vocabulary_size=voc_size_list[0], host_vocabulary_size=voc_size_list[1], + ssd_vocabulary_size=voc_size_list[2], ssd_data_path=ssd_data_path, + init_param=init_param, is_save=is_save, all2all_gradients_op=all2all_gradients_op) # 动态扩容 if ConfigInitializer.get_instance().use_dynamic_expansion: return HBMDynamicSparseEmbeddingFactory().create_embedding(config) # DDR or SSD - if host_vocabulary_size > 0: + if voc_size_list[1] > 0: return ExternalStorageSparseEmbeddingFactory().create_embedding(config) # HBM return HBMSparseEmbeddingFactory().create_embedding(config) @para_checker_decorator(check_option_list=[ - ("hashtable", ClassValidator, {"classes": (BaseSparseEmbedding, )}), + ("hashtable", ClassValidator, {"classes": (BaseSparseEmbedding,)}), ("ids", ClassValidator, {"classes": (FeatureSpec, tf.Tensor)}), - ("is_train", ClassValidator, {"classes": (bool, )}), + ("is_train", ClassValidator, {"classes": (bool,)}), ("send_count", ClassValidator, {"classes": (int, type(None))}), ("send_count", OptionalIntValidator, {"min_value": 1, "max_value": MAX_INT32}, ["check_value"]), ("name", ClassValidator, {"classes": (str, type(None))}), @@ -118,7 +122,7 @@ def create_table(key_dtype, dim, name, emb_initializer, ("modify_graph", ClassValidator, {"classes": (bool, type(None))}), ("batch", ClassValidator, {"classes": (dict, type(None))}), 
("access_and_evict_config", ClassValidator, {"classes": (dict, type(None))}), - ("is_grad", ClassValidator, {"classes": (bool, )}), + ("is_grad", ClassValidator, {"classes": (bool,)}), ("serving_default_value", ClassValidator, {"classes": (tf.Tensor, type(None))}) ]) def sparse_lookup(hashtable: BaseSparseEmbedding, @@ -172,7 +176,7 @@ def sparse_lookup(hashtable: BaseSparseEmbedding, # 对于向上找没有IteratorGetNext的孤儿ids需要标记,以便于后续ACGPushOpsToDataset工作 if isinstance(ids, tf.Tensor): - ids = tag_orphan_ids(ids) + ids = mark_orphan_lookup_key(ids) with tf.compat.v1.variable_scope("{0}//{1}".format(hashtable.table_name, kwargs.get("name"))): if isinstance(ids, FeatureSpec): @@ -188,3 +192,52 @@ def sparse_lookup(hashtable: BaseSparseEmbedding, ConfigInitializer.get_instance().modify_graph = modify_graph return hashtable.lookup(ids, send_count, **kwargs) + + +def mark_orphan_lookup_key(lookup_key: Tensor) -> Tensor: + graph_def = tf.compat.v1.get_default_graph().as_graph_def() + subgraph = tf.compat.v1.graph_util.extract_sub_graph(graph_def, [lookup_key.op.name]) + + for node in subgraph.node: + if node.op == AnchorIteratorOp.ITERATOR_GET_NEXT.value: + return lookup_key + + name_prefix = constants.ORPHAN_LOOKUP_KEY_PREFIX + marked_lookup_key = tf.identity(lookup_key, name="{}/{}".format(name_prefix, lookup_key.op.name)) + + logger.info('Mark orphan lookup key %s as %s.', lookup_key, marked_lookup_key) + return marked_lookup_key + + +def check_and_set_default_voc_size(voc_size_list: List[int], dim_bytes: int): + if ConfigInitializer.get_instance().use_dynamic_expansion: + voc_size_list[1] = 0 + voc_size_list[2] = 0 + return + cache_mode = os.getenv("CACHE_MODE") + if not cache_mode and voc_size_list[0] <= 1: + raise ValueError("no cache mode, no use_dynamic_expansion, must input dev-voc") + if not cache_mode and voc_size_list[1] == 0 and voc_size_list[2] == 0: # no cache mode, dev-voc not None, use HBM + return + if not cache_mode and voc_size_list[1] == 0 and voc_size_list[2] > 0: + raise ValueError("no cache mode, dev-voc is not none and host-voc is none, ssd-voc must be none too") + if not cache_mode and voc_size_list[2] == 0: # no cache mode, dev-voc/host-voc not None, use DDR + return + if not cache_mode: # no cache mode, dev-voc/host-voc/ssd-voc not None, use SSD + return + + if cache_mode not in [mode.value for mode in CacheModeEnum]: + raise ValueError("cache mode need to fit HBM, DDR, SSD") + if cache_mode == CacheModeEnum.HBM.value and (voc_size_list[1] > 0 or voc_size_list[2] > 0): + raise ValueError("cache mode HBM, host-voc or ssd-voc is need to be none") + if cache_mode == CacheModeEnum.DDR.value and voc_size_list[2] > 0: + raise ValueError("cache mode DDR, ssd-voc is need to be none") + if voc_size_list[0] == 1: + default_device_voc_size = int(DEFAULT_DEVICE_CACHE_MEMORY_SIZE / dim_bytes * get_rank_size()) # single rank 2GB + voc_size_list[0] = min(default_device_voc_size, MAX_DEVICE_VOCABULARY_SIZE) + if (cache_mode == CacheModeEnum.DDR.value or cache_mode == CacheModeEnum.SSD.value) and voc_size_list[1] == 0: + default_host_voc_size = int(DEFAULT_HOST_CACHE_MEMORY_SIZE / dim_bytes) # total 40GB + voc_size_list[1] = min(default_host_voc_size, MAX_VOCABULARY_SIZE) + if cache_mode == CacheModeEnum.SSD.value and voc_size_list[2] == 0: + voc_size_list[2] = MAX_VOCABULARY_SIZE + return diff --git a/mx_rec/core/feature_process.py b/mx_rec/core/feature_process.py index 3963f6d56e27072df23b14c1174790c25979b5b5..a2161d026976109bcdc78b2f8f68aa6815778aad 100644 --- 
a/mx_rec/core/feature_process.py +++ b/mx_rec/core/feature_process.py @@ -50,9 +50,9 @@ class EvictHook(tf.compat.v1.train.SessionRunHook): self._global_step_tensor = None if evict_step_interval is None: - logger.info(f"_EvictHook - > evict_time_interval: %d", self._evict_time_interval) + logger.info("_EvictHook - > evict_time_interval: %d", self._evict_time_interval) else: - logger.info(f"_EvictHook - > evict_time_interval: %d, evict_step_interval: %d", + logger.info("_EvictHook - > evict_time_interval: %d, evict_step_interval: %d", self._evict_time_interval, self._evict_step_interval) def begin(self): @@ -61,6 +61,8 @@ class EvictHook(tf.compat.v1.train.SessionRunHook): raise RuntimeError("Global step should be created to use _EvictHook.") self.check_name_and_get_hashtable() for name, instance in self._hash_table_instance.items(): + if not instance.is_hbm: + continue scope_name = f"{instance.table_name}//evict" with tf.compat.v1.variable_scope(scope_name): logger.debug('Channel %s_evict_%d was built for op getnext', instance.table_name, TRAIN_CHANNEL_ID) @@ -99,7 +101,9 @@ class EvictHook(tf.compat.v1.train.SessionRunHook): if not ConfigInitializer.get_instance().hybrid_manager_config.trigger_evict(): return self._start_time = cur_time - for name in self._hash_table_instance.keys(): + for name, instance in self._hash_table_instance.items(): + if not instance.is_hbm: + continue run_context.session.run(self._evict_op.get(name)) def check_name_and_get_hashtable(self): diff --git a/mx_rec/graph/__init__.py b/mx_rec/graph/__init__.py index f4d2642c1de1e8ca6ea4ff629d86327b762140d6..f14659714a881926a1c3becceb067cae55aa42ff 100644 --- a/mx_rec/graph/__init__.py +++ b/mx_rec/graph/__init__.py @@ -15,8 +15,12 @@ # limitations under the License. # ============================================================================== -__all__ = ["modify_graph_and_start_emb_cache", "GraphModifierHook", "run", "ACGPushOpsToDatasetHook"] +__all__ = [ + "GraphModifierHook", + "LookupSubgraphSlicerHook", + "OrphanLookupKeySlicerHook", + "modify_graph_and_start_emb_cache", +] from mx_rec.graph.modifier import GraphModifierHook, modify_graph_and_start_emb_cache -from mx_rec.graph.patch import run -from mx_rec.graph.acg_push_ops import ACGPushOpsToDatasetHook +from mx_rec.graph.hooks import LookupSubgraphSlicerHook, OrphanLookupKeySlicerHook diff --git a/mx_rec/graph/acg_push_ops.py b/mx_rec/graph/acg_push_ops.py deleted file mode 100644 index 625ef92fc4f68988a25e5a01940a6c080a5d1680..0000000000000000000000000000000000000000 --- a/mx_rec/graph/acg_push_ops.py +++ /dev/null @@ -1,641 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -from typing import Dict, Tuple, List, Set - -import tensorflow as tf -from tensorflow.python.data.ops.dataset_ops import DatasetV1Adapter -from tensorflow.python.framework.ops import Operation -from tensorflow.python.util import nest as tf_nest -from tensorflow.core.framework import node_def_pb2 -from tensorflow.core.framework import attr_value_pb2 -from tensorflow.python.framework import tensor_util - -from mx_rec.graph import modifier -from mx_rec.util.log import logger -from mx_rec.graph.utils import export_pb_graph -from mx_rec.graph.graph_typing import SubgraphInfo -from mx_rec.constants.constants import ASCEND_TIMESTAMP, ANCHOR_DATASET_NAME, MAX_WHILE_SIZE, AnchorIteratorOp -from mx_rec.validator.validator import para_checker_decorator, ClassValidator - -tf.compat.v1.disable_eager_execution() - -_ACG_NEW_NODE_PREFIX = "ACG_" -_ACG_NEW_ITERATOR = "ACG_NEW_ITERATOR" -_ACG_NEW_INITIALIZER = "ACG_NEW_INITIALIZER" - -_OP_TYPE_TO_PUSH = frozenset(["StringSplit", "StringToNumber"]) -_OP_TYPE_TO_IGNORE = frozenset([AnchorIteratorOp.ITERATOR_GET_NEXT]) -_OP_TYPE_CONTAIN_STRING_TO_IGNORE = frozenset(["Dataset", "Summary"]) -_OP_NAME_CONTAIN_STRING_TO_IGNORE = frozenset(["save", "report_", "loss"]) -_OP_NAME_CONTAIN_STRING_TO_PUSH = frozenset(["ACG_PUSH_NODE"]) - -_TENSOR_TYPE_TO_IGNORE = frozenset([tf.variant, tf.resource]) - -_VARIABLE_TYPES = frozenset(["Variable", "VariableV2", "VarHandleOp"]) -_IGNORE_REPLACE_NODE = frozenset(["Assign", "SaveV2"]) - - -class ACGPushOpsToDatasetHook(tf.estimator.SessionRunHook): - @para_checker_decorator( - check_option_list=[ - ("dump_graph", ClassValidator, {"classes": (bool,)}), - ] - ) - def __init__(self, dump_graph: bool = False) -> None: - super().__init__() - self._dump_graph = dump_graph - - modifier.get_src_dataset = _patched_get_src_dataset - logger.info("[ACGPushOpsToDatasetHook] The function `get_src_dataset` of modifier has been replaced!") - - def begin(self): - logger.info("[ACGPushOpsToDataset] Trigger at beginning!") - graph = tf.compat.v1.get_default_graph() - _find_ops_to_be_pushed(graph=graph, dump_graph=self._dump_graph) - - def after_create_session(self, session, coord): - logger.info("[ACGPushOpsToDatasetHook] Trigger after create session!") - initializers = tf.compat.v1.get_collection(_ACG_NEW_INITIALIZER) - logger.info(f"[ACGPushOpsToDatasetHook] Got new initialzers: %s.", initializers) - session.run(initializers) - - def end(self, session): - logger.info("[ACGPushOpsToDatasetHook] Trigger in the end!") - - -def _find_ops_to_be_pushed(graph: tf.Graph, dump_graph: bool = False): - export_pb_graph("before_push_graph.pbtxt", dump_graph, graph_def=graph.as_graph_def()) - op_nodes = graph.get_operations() - nodes_to_push = set() - - for op_node in op_nodes: - if op_node.type in _OP_TYPE_TO_IGNORE: - continue - - pushable = False - if op_node.type in _OP_TYPE_TO_PUSH: - pushable = True - - for ignore_type in _OP_TYPE_CONTAIN_STRING_TO_IGNORE: - if ignore_type in op_node.type: - pushable = False - if not pushable: - continue - for ignore_name in _OP_NAME_CONTAIN_STRING_TO_IGNORE: - if ignore_name in op_node.name: - pushable = False - if not pushable: - continue - for each_tensor in list(op_node.outputs) + list(op_node.inputs): - if each_tensor.dtype in _TENSOR_TYPE_TO_IGNORE: - pushable = False - if not pushable: - continue - - for push_name in _OP_NAME_CONTAIN_STRING_TO_PUSH: - if push_name in op_node.name: - pushable = True - break - - if pushable: - 
nodes_to_push.add(op_node) - - if not nodes_to_push: - logger.info("No target op has to be pushed to dataset map func!") - return - - logger.info("Found operations should be pushed: %s.", nodes_to_push) - subgraph_nodes = _find_subgraph_nodes( - graph, nodes_to_push, tgt_op_type=AnchorIteratorOp.ITERATOR_GET_NEXT.value, exclude_tgt_op=True - ) - _push_subgraph_to_dataset(graph, subgraph_nodes, dump_graph) - export_pb_graph("after_push_graph.pbtxt", dump_graph, graph_def=graph.as_graph_def()) - - -def _find_subgraph_nodes( - graph: tf.Graph, - base_nodes: Set[tf.Operation], - tgt_op_type: str, - exclude_tgt_op: bool = True, -) -> Set[tf.Operation]: - subgraph_nodes = set() - visited_nodes = base_nodes - found_nodes = base_nodes - all_nodes = graph.get_operations() - logger.info("Got base_nodes: %s.", base_nodes) - - loop_cnt = 0 - while len(found_nodes) > 0: - loop_cnt += 1 - if loop_cnt > MAX_WHILE_SIZE: - raise RuntimeError(f"In bfs_lookup function, the maximum cycle depth is greater than {MAX_WHILE_SIZE}.") - - base_nodes = set() - for parent_node in found_nodes: - if (not exclude_tgt_op) and parent_node.type == tgt_op_type: - continue - base_nodes.add(parent_node) - found_nodes = set() - for base_node in base_nodes: - tmp_nodes = [x.op for x in base_node.inputs] + base_node.control_inputs - _warn_for_var_scope_nodes(all_nodes, base_node) - - tmp_nodes = set(tmp_nodes) - visited_nodes - if exclude_tgt_op: - tmp_nodes = set(filter(lambda node: node.type != tgt_op_type, tmp_nodes)) - found_nodes.update(tmp_nodes) - visited_nodes.update(tmp_nodes) - - subgraph_nodes.update(visited_nodes) - logger.info("Found subgraph from nodes_to_push: %s.", subgraph_nodes) - return subgraph_nodes - - -def _warn_for_var_scope_nodes(all_nodes: List[tf.Operation], base_node: tf.Operation): - if base_node.type in _VARIABLE_TYPES: - for x in base_node.outputs: - varable_scope_node = [x for x in all_nodes if x.name.startswith(f"{base_node.name}/")] - logger.warning("Got base_node: %s and varable_scope_node: %s.", base_node, varable_scope_node) - - -def _find_op_from_base_op(base_ops: tf.Operation, target_op_type: str) -> tf.Operation: - base_ops = modifier.check_input_list(base_ops, tf.Operation) - parent_ops = base_ops - while True: - for parent_op in parent_ops: - if parent_op.type == target_op_type: - return parent_op - base_ops = parent_ops - parent_ops = [] - for base_op in base_ops: - parent_ops.extend(modifier.find_parent_op(base_op)) - if not parent_ops: - raise ValueError(f"Op {target_op_type} was not found.") - - -def _get_dataset_op(graph: tf.Graph, get_next_op: Operation) -> Operation: - if get_next_op.type != AnchorIteratorOp.ITERATOR_GET_NEXT.value: - raise TypeError("Op '{get_next_op}' must be one instance of IteratorGetNext.") - # looking for the MakeIterator operator which corresponds to given batch_tensor - base_op = modifier.find_make_iterator_op(get_next_op.outputs[0]) - # looking for the op which is the one before OptimizeDataset operator - if tf.__version__.startswith("1"): - optimize_dataset_op = _find_op_from_base_op(base_op, "ModelDataset") - target_op = modifier.find_parent_op(optimize_dataset_op) - if not target_op: - raise RuntimeError(f"The parent op for 'ModelDataset' op was not found.") - if target_op[0].type != "OptimizeDataset": - raise TypeError(f"Op OptimizeDataset was not found.") - target_op = target_op[0] - else: - # 'OptimizeDataset' is not available in TensorFlow2.X - raise RuntimeError("Not supoprt tf2") - return target_op - - -def 
_ordered_output_from_subgraph(subgraph_out: Dict[tf.Operation, Set[tf.Operation]]) -> List[tf.Tensor]: - addition_funcgraph_output_tensor = [] - for k, v in sorted(subgraph_out.items(), key=lambda x: x[0].name): - k_inputs = set(k.inputs) - for node in v: - _add_sorted_additional_tensors(addition_funcgraph_output_tensor, k_inputs, node) - return addition_funcgraph_output_tensor - - -def _add_sorted_additional_tensors(addition_funcgraph_output_tensor, k_inputs, node): - for each_tensor in sorted(node.outputs, key=lambda x: x.name): - if each_tensor in k_inputs: - addition_funcgraph_output_tensor.append(each_tensor) - - -def _get_tensor_consumers_unsafe(tensor: tf.Tensor) -> List[tf.Operation]: - if isinstance(tensor, tf.Operation): - raise RuntimeError("not support type: {node}") - - from tensorflow.python import pywrap_tensorflow as c_api - - consumer_names = c_api.TF_OperationOutputConsumers_wrapper(tensor._as_tf_output()) - graph = tensor.graph - result = [] - for name in consumer_names: - with graph._lock: - if name in graph._nodes_by_name: # ignore deleted node - result.append(graph._nodes_by_name[name]) - - return result - - -def _push_subgraph_to_dataset(graph: tf.Graph, subgraph_to_push: Set[tf.Operation], dump_graph: bool = False): - subgraph_in, subgraph_out = _find_subgraph_in_out(subgraph_to_push) - logger.info("Got input tensor of extracted subgraph: %s", subgraph_in) - logger.info("Got output tensor of extracted subgraph: %s", subgraph_out) - - get_next_node = graph.get_operation_by_name(AnchorIteratorOp.ITERATOR_GET_NEXT.value) - src_dataset = _get_src_dataset(graph, get_next_node) - - def acg_func(*x): # pragma: no cover - old_x = x - logger.debug("Got old batch layout: %s", x) - - x = tf_nest.flatten(x) - for each_tensor in x: - if not isinstance(each_tensor, tf.Tensor): - raise RuntimeError(f"Expected tensor as input of mapfunc. 
but got: {x}!") - - funcgraph = tf.compat.v1.get_default_graph() - subgraph_info = SubgraphInfo(subgraph_in, subgraph_out, subgraph_to_push) - new_batch = _clone_subgraph_into_funcgraph( - funcgraph, - graph, - subgraph_info, - x, - old_x, - ) - - logger.debug("Got new batch layout: %s.", new_batch) - export_pb_graph("map_func_graph.pbtxt", dump_graph, graph_def=funcgraph.as_graph_def()) - return new_batch - - tgt_dataset = src_dataset.map(acg_func) - tgt_dataset = tgt_dataset.prefetch(0) - _update_iterator_getnext( - graph=graph, - get_next_op=get_next_node, - tgt_dataset=tgt_dataset, - subgraph_out=subgraph_out, - subgraph_to_push=subgraph_to_push, - ) - - -def _find_subgraph_in_out( - sub_graph_nodes: Set[tf.Operation], -) -> Tuple[Dict[tf.Operation, Set[tf.Operation]], Dict[tf.Operation, Set[tf.Operation]]]: - relay_input_nodes = set() - relay_output_nodes = set() - input_to_subnodes = dict() - output_to_subnodes = dict() - - for base_node in sub_graph_nodes: - _update_subgraph_in(base_node, input_to_subnodes, relay_input_nodes, sub_graph_nodes) - _update_subgraph_out(base_node, output_to_subnodes, relay_output_nodes, sub_graph_nodes) - - return input_to_subnodes, output_to_subnodes - - -def _update_subgraph_in( - base_node: tf.Operation, - input_to_subnodes: Dict[tf.Operation, Set[tf.Operation]], - relay_input_nodes: Set[tf.Operation], - sub_graph_nodes: Set[tf.Operation], -): - for input_tensor in base_node.inputs: - input_node = input_tensor.op - if input_node not in sub_graph_nodes: - relay_input_nodes.add(input_node) - res = input_to_subnodes.get(input_node, set()) - res.add(base_node) - input_to_subnodes[input_node] = res - - -def _update_subgraph_out( - base_node: tf.Operation, - output_to_subnodes: Dict[tf.Operation, Set[tf.Operation]], - relay_output_nodes: Set[tf.Operation], - sub_graph_nodes: Set[tf.Operation], -): - for output_tensor in base_node.outputs: - for output_consumer in output_tensor.consumers(): - if output_consumer not in sub_graph_nodes: - relay_output_nodes.add(output_consumer) - res = output_to_subnodes.get(output_consumer, set()) - res.add(base_node) - output_to_subnodes[output_consumer] = res - - -def _get_src_dataset(graph: tf.Graph, get_next_op: Operation) -> DatasetV1Adapter: - try: - target_op = _get_dataset_op(graph, get_next_op) - except (ValueError, TypeError, RuntimeError) as err: - logger.warning("The dataset op was not found, the error is %s. Start to traverse the operations.", err) - dataset_op_list = [op for op in graph.get_operations() if ANCHOR_DATASET_NAME in op.name] - if len(dataset_op_list) != 1: - raise RuntimeError( - f"The `{ANCHOR_DATASET_NAME}` was not found from the operations, dataset_op_list: " - f"{dataset_op_list}." 
- ) from err - target_op = dataset_op_list[0] - except Exception as err: - raise RuntimeError(f"The dataset was not found, the error is `{err}`.") from err - if not target_op.outputs: - raise ValueError(f"The length of the outputs of target op `{target_op}` is 0.") - logger.info("Find target op `%s`, and output is `%s`.", target_op.name, target_op.outputs) - src_dataset = modifier.find_target_instance_dataset(target_op.outputs[0]) - return src_dataset - - -def _clone_subgraph_into_funcgraph( - funcgraph: tf.Graph, - defaultgraph: tf.Graph, - subgraph_info: SubgraphInfo, - x: List[tf.Tensor], - old_x: Tuple[Dict[str, tf.Tensor]], -) -> Dict[str, tf.Tensor]: - topo_subgraph_list = _topo_subgraph(subgraph_info.subgraph_to_push) # node - tensor_mapping = {} # subgraph-tensor -> funcgraph-tensor - node_mapping = {} # subgraph-node -> funcgraph-node - for k, v in subgraph_info.subgraph_in.items(): - _get_mapping_for_subgraph_in(k, v, x, tensor_mapping) - for old_node in topo_subgraph_list: - _get_mapping_for_subgraph(funcgraph, defaultgraph, node_mapping, old_node, tensor_mapping) - - logger.info("Got node_mapping: %s", node_mapping) - logger.info("Got tensor_mapping: %s", tensor_mapping) - - ordered_output_subgraph_tensors = _ordered_output_from_subgraph(subgraph_info.subgraph_out) - addition_funcgraph_output_tensor = _get_mapping_tensor(tensor_mapping, ordered_output_subgraph_tensors) - new_funcgraph_output_tensor = list(x) + addition_funcgraph_output_tensor - logger.info("Got new_funcgraph_output_tensor: %s", new_funcgraph_output_tensor) - - new_x = old_x[0] - for tensor in addition_funcgraph_output_tensor: - last_key = f"{sorted(new_x)[-1]}_last_key" - new_x[last_key] = tensor - - return new_x - - -def _get_mapping_for_subgraph_in( - from_node: tf.Operation, to_nodes: Set[tf.Operation], x: List[tf.Tensor], tensor_mapping -): - if from_node.type != AnchorIteratorOp.ITERATOR_GET_NEXT.value: - raise RuntimeError(f"Expect IteratorGetNext for input tensor of subgraph, but got {from_node}") - for node in to_nodes: - for each_tensor in node.inputs: - if each_tensor.op.type != AnchorIteratorOp.ITERATOR_GET_NEXT.value: - continue - old_tensor_name = each_tensor.name - x_index = int(old_tensor_name.split(":")[-1]) - tensor_mapping[each_tensor] = x[x_index] - - -def _get_mapping_for_subgraph( - funcgraph: tf.Graph, - defaultgraph: tf.Graph, - node_mapping: Dict[tf.Operation, tf.Operation], - old_node: tf.Operation, - tensor_mapping: Dict[tf.Tensor, tf.Tensor], -): - logger.debug("old_node: %s \n old_node_inputs: %s", old_node, [x for x in old_node.inputs]) - node_def = old_node.node_def - for each_tensor in old_node.inputs: - if each_tensor not in tensor_mapping: - raise RuntimeError( - f"each_tensor(input) {each_tensor} need by {old_node.name} not in tensor_mapping.{tensor_mapping}" - ) - new_inputs = _get_mapping_tensor(tensor_mapping, old_node.inputs) - if old_node.type in _VARIABLE_TYPES: - node_def = _frozen_variable_node_to_func_const_node_def( - variable_node=old_node, funcgraph=funcgraph, defaultgraph=defaultgraph - ) - node_def.name = _ACG_NEW_NODE_PREFIX + node_def.name - new_node = tf.Operation(node_def=node_def, g=funcgraph, inputs=new_inputs) - node_mapping[old_node] = new_node - for old_out_tensor, new_out_tensor in zip(old_node.outputs, new_node.outputs): - tensor_mapping[old_out_tensor] = new_out_tensor - - -def _frozen_variable_node_to_func_const_node_def( - variable_node: tf.Operation, funcgraph: tf.Graph, defaultgraph: tf.Graph -) -> node_def_pb2.NodeDef: - def 
create_const_node_def(node_name, dtype, data, data_shape=None): - """Creates a Const op.""" - output_node = node_def_pb2.NodeDef() - output_node.op = "Const" - output_node.name = node_name - output_node.attr["dtype"].CopyFrom(dtype) - output_node.attr["value"].CopyFrom( - attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(data, dtype=dtype.type, shape=data_shape)) - ) - return output_node - - # NOTE: Variable node type is readonly in funcgraph, all nodes of this type have to be fronzen. - variable_name = variable_node.name - if variable_node.type == "VarHandleOp": - variable_name = f"{variable_name}/Read/ReadVariableOp:0" - else: - variable_name = f"{variable_name}:0" - initializer = defaultgraph.get_operation_by_name(f"{variable_node.name}/Assign") - logger.info(f"VariableV2: {variable_node.name}, initializer: {initializer.name} ") - defaultsession = tf.compat.v1.Session(graph=defaultgraph) - _ = defaultsession.run([initializer]) - logger.info(f"Start run variables data: {variable_name}") - returned_variable_data = defaultsession.run(variable_name) - logger.info(f"Start froze variables: {variable_name} {returned_variable_data}") - new_const_node = create_const_node_def( - variable_node.name, variable_node.node_def.attr["dtype"], returned_variable_data, returned_variable_data.shape - ) - return new_const_node - - -def _get_mapping_tensor(tsr2tsr: Dict[tf.Tensor, tf.Tensor], keys: List[tf.Tensor]) -> List[tf.Tensor]: - tensors = [] - for k in keys: - if k not in tsr2tsr: - raise KeyError(f"Failed to find key tensor: {k} from tensor map: {tsr2tsr}.") - tensors.append(tsr2tsr[k]) - return tensors - - -def _topo_subgraph(subgraph: Set[tf.Operation]) -> List[tf.Operation]: - topo_subgraph_list = [] - topo_subgraph_set = set() - start_nodes = set() - [start_nodes.add(x) for x in subgraph] - logger.info("Got topo_subgraph start nodes: %s", start_nodes) - - def topo_subgraph_dfs(curr_node, output_list, output_set): - if not isinstance(curr_node, tf.Operation): - raise RuntimeError(f"topo_subgraph_dfs input should be node(aka. tf.Operator). {curr_node}") - curr_inputs = curr_node.inputs - logger.debug("Got topo_dfs: %s <- %s", curr_node.name, [x.name for x in curr_inputs]) - current_control_inputs = curr_node.control_inputs - if len(current_control_inputs) > 0: - raise RuntimeError( - f"Control input are not supported: {curr_node.name}, control_inputs: {current_control_inputs}" - ) - if curr_node in output_set: - return - output_set.add(curr_node) - for tensor in curr_inputs: - node = tensor.op - if node.type != AnchorIteratorOp.ITERATOR_GET_NEXT.value and node not in output_set: - topo_subgraph_dfs(node, output_list, output_set) - output_list.append(curr_node) - - [topo_subgraph_dfs(x, topo_subgraph_list, topo_subgraph_set) for x in start_nodes] - if len(topo_subgraph_list) != len(topo_subgraph_set): - raise RuntimeError(f"Got duplicated topo node: {sorted(topo_subgraph_list, key=lambda x: x.name)}.") - logger.info("Got topo_subgraph: %s", topo_subgraph_list) - return topo_subgraph_list - - -def _update_iterator_getnext( - graph: tf.Graph, - get_next_op: Operation, - tgt_dataset: DatasetV1Adapter, - subgraph_out: Dict[tf.Operation, Set[tf.Operation]], - subgraph_to_push: Set[tf.Operation], -): - if not get_next_op.outputs: - raise RuntimeError("There is no tensor in the dataset. 
Please check the dataset and data processing.") - iterator_type = "" - if get_next_op.inputs: - iterator_type = get_next_op.inputs[0].op.type - if iterator_type == "IteratorV2": - iterator_type = modifier.find_make_iterator_op(get_next_op.outputs[0]).type - if iterator_type not in (AnchorIteratorOp.MAKE_ITERATOR.value, AnchorIteratorOp.ONE_SHOT_ITERATOR.value): - raise RuntimeError( - f"Only iterators `MakeIterator` and `OneShotIterator` are supported in `graph modify` mode, " - f"but the current iterator is `{iterator_type}`." - ) - logger.info("The iterator type of dataset is %s.", iterator_type) - if iterator_type == AnchorIteratorOp.MAKE_ITERATOR.value: - new_iterator = tgt_dataset.make_initializable_iterator() - logger.info("Got new_iterator: %s, new_iterator.initializer: %s.", new_iterator, new_iterator.initializer) - graph.add_to_collection(_ACG_NEW_INITIALIZER, new_iterator.initializer) - else: - new_iterator = tgt_dataset.make_one_shot_iterator() - new_batch = new_iterator.get_next(_ACG_NEW_ITERATOR) - if "timestamp" in new_batch.keys(): - tf.compat.v1.add_to_collection(ASCEND_TIMESTAMP, new_batch["timestamp"]) - try: - new_batch_tensor = new_batch - while not isinstance(new_batch_tensor, tf.Tensor): - if isinstance(new_batch_tensor, tuple): - new_batch_tensor = new_batch_tensor[0] - elif isinstance(new_batch_tensor, dict): - new_batch_tensor = list(new_batch_tensor.values()) - elif isinstance(new_batch_tensor, list): - new_batch_tensor = new_batch_tensor[0] - elif isinstance(new_batch_tensor, tf.Tensor): - break - else: - raise RuntimeError( - f"Need to support new_batch_tensor{new_batch_tensor}, type: {type(new_batch_tensor)}" - ) - except IndexError as err: - raise IndexError("Cannot find a tensor from given batch.") from err - new_get_next_op = _find_op_from_base_op(new_batch_tensor.op, AnchorIteratorOp.ITERATOR_GET_NEXT.value) - logger.info("Got new_get_next_op: %s.", new_get_next_op) - _replace_get_next_op(graph, get_next_op, new_get_next_op, subgraph_out, subgraph_to_push) - - -def _replace_get_next_op( - graph: tf.Graph, - old_get_next_op: tf.Operation, - new_get_next_op: tf.Operation, - subgraph_out: Dict[tf.Operation, Set[tf.Operation]], - subgraph_to_push: Set[tf.Operation], -): - for output_tensor in old_get_next_op.outputs: - _update_old_consumer(graph, new_get_next_op, output_tensor, subgraph_to_push) - - old_get_next_op_output_size = len(old_get_next_op.outputs) - ordered_output_tensor = _ordered_output_from_subgraph(subgraph_out) - - for i, output_tensor in enumerate(ordered_output_tensor): - offset = old_get_next_op_output_size + i - _update_subgraph_out_consumer(graph, new_get_next_op, offset, output_tensor) - - -def _update_old_consumer( - graph: tf.Graph, new_get_next_op: tf.Operation, output_tensor: tf.Tensor, subgraph_to_push: List[tf.Operation] -): - old_tensor_name = output_tensor.name - output_index = old_tensor_name.split(":")[-1] - new_tensor_name = f"{new_get_next_op.name}:{output_index}" - logger.info("Replace old_tensor_name: %s to new_tensor_name: %s", old_tensor_name, new_tensor_name) - new_tensor = graph.get_tensor_by_name(new_tensor_name) - for output_consumer in _get_tensor_consumers_unsafe(output_tensor): - if output_consumer in subgraph_to_push: - logger.info( - "Ignore consumer in old subgraph %s, not let it connect to new IteratorGetNext.", output_consumer - ) - continue - for i, consumer_input in enumerate(output_consumer.inputs): - if consumer_input != output_tensor: - logger.debug("Not replace output_consumer: %s consumer_input: %s.", 
output_consumer, consumer_input) - continue - logger.info( - "Success replace output_consumer: %s type: %s from consumer_input: %s to new_tensor: %s", - output_consumer.name, - output_consumer.type, - consumer_input, - new_tensor, - ) - output_consumer._update_input(i, new_tensor) - - -def _update_subgraph_out_consumer( - graph: tf.Graph, new_get_next_op: tf.Operation, offset: int, output_tensor: tf.Tensor -): - new_tensor_name = f"{new_get_next_op.name}:{offset}" - logger.info("Replace old_tensor_name: %s to new_tensor_name: %s.", output_tensor.name, new_tensor_name) - new_tensor = graph.get_tensor_by_name(new_tensor_name) - for output_consumer in _get_tensor_consumers_unsafe(output_tensor): - if output_consumer.type in _IGNORE_REPLACE_NODE: - logger.info("Ignore replace output_consumer: %s, it's of type: %s.", output_consumer, output_consumer.type) - continue - for j, consumer_input in enumerate(output_consumer.inputs): - if consumer_input != output_tensor: - logger.debug("Not replace output_consumer: %s consumer_input: %s.", output_consumer, consumer_input) - continue - logger.info( - "Success replace output_consumer: %s type: %s from consumer_input: %s to new_tensor: %s", - output_consumer.name, - output_consumer.type, - consumer_input, - new_tensor, - ) - output_consumer._update_input(j, new_tensor) - - -def _patched_get_src_dataset(get_next_op: Operation, is_training: bool) -> DatasetV1Adapter: - try: - target_op = modifier.get_dataset_op(get_next_op) - except (ValueError, TypeError, RuntimeError) as err: - logger.debug("In `OneShotIterator` mode, find `PrefetchDataset` from all ops in graph.") - graph = tf.compat.v1.get_default_graph() - dataset_op_list = [op for op in graph.get_operations() if ANCHOR_DATASET_NAME in op.name] - dataset_op_list = sorted(dataset_op_list, key=lambda op: op.name) - logger.debug("Got sorted dataset_op_list: %s.", dataset_op_list) - if len(dataset_op_list) != 2: - raise RuntimeError( - f"Expect two `PrefetchDataset` ops in dataset_op_list, but got: {dataset_op_list}." - ) from err - target_op = dataset_op_list[1] - except Exception as err: - raise RuntimeError(f"The source dataset can't be found, got error: {err}.") from err - - if not target_op.outputs: - raise ValueError(f"The length of the outputs of target op `{target_op}` is 0.") - - logger.debug("Find target dataset op: %s, and output is %s.", target_op, target_op.outputs) - src_dataset = modifier.find_target_instance_dataset(target_op.outputs[0]) - - return src_dataset diff --git a/mx_rec/graph/constants.py b/mx_rec/graph/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..6c67b20195dee28f61565201297eacd6b188078b --- /dev/null +++ b/mx_rec/graph/constants.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
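Both the deleted `acg_push_ops.py` above and the rewritten `modifier.py` below rest on the same mechanism: prune a `GraphDef` down to the nodes a set of outputs depends on, then splice it into another graph through `tf.import_graph_def` with an `input_map`. A self-contained sketch of that mechanism, using a toy graph and illustrative tensor names:

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

with tf.Graph().as_default() as source_graph:
    x = tf.compat.v1.placeholder(tf.float32, shape=[None], name="x")
    y = tf.add(x * 2.0, 1.0, name="y")
    # Keep only the nodes that "y" transitively depends on.
    sub_graph_def = tf.compat.v1.graph_util.extract_sub_graph(
        source_graph.as_graph_def(), ["y"]
    )

with tf.Graph().as_default():
    new_input = tf.constant([1.0, 2.0, 3.0], name="new_input")
    # Remap the old input tensor onto the new one and pull out the old output.
    (out,) = tf.import_graph_def(
        sub_graph_def, input_map={"x:0": new_input}, return_elements=["y:0"]
    )
    with tf.compat.v1.Session() as sess:
        print(sess.run(out))  # [3. 5. 7.]
```

Inside a `Dataset.map` function, the same call rewires the extracted sub-graph onto the map function's arguments, which is what `_get_preprocessing_map_func` does further down in `modifier.py`.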
+# ============================================================================== + +from enum import Enum + + +class DeprecatedOp(Enum): + DEPRECATED_ITERATOR_GET_NEXT = "DEPRECATED_ITERATOR_GET_NEXT" + DEPRECATED_PREFETCH_DATASET = "DEPRECATED_PREFETCH_DATASET" + + +class AnchorDatasetOp(Enum): + MODEL_DATASET = "ModelDataset" + OPTIMIZE_DATASET = "OptimizeDataset" + PREFETCH_DATASET = "PrefetchDataset" + + +class AnchorIteratorOp(Enum): + ITERATOR_GET_NEXT = "IteratorGetNext" + ITERATOR_V2 = "IteratorV2" + MAKE_ITERATOR = "MakeIterator" + ONE_SHOT_ITERATOR = "OneShotIterator" diff --git a/mx_rec/graph/graph_typing.py b/mx_rec/graph/graph_typing.py deleted file mode 100644 index c11bd4c0edc1f163eeba4e26c88223c34b9727c6..0000000000000000000000000000000000000000 --- a/mx_rec/graph/graph_typing.py +++ /dev/null @@ -1,35 +0,0 @@ -# !/usr/bin/env python3 -# -- coding: utf-8 -- -# Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. - -import dataclasses -from typing import Dict, DefaultDict, List, Tuple, Set - -from tensorflow import Operation, Tensor -from tensorflow.core.framework.graph_pb2 import GraphDef - - -# DefaultDict: -# Key: Tensor => Represent output tensor of `IteratorGetNext` operation. -# Val: List[Tuple[int, Operation]] => Contains target operation of output tensor and it's corresponding index. -ReplacementSpec = DefaultDict[Tensor, List[Tuple[int, Operation]]] - - -@dataclasses.dataclass -class AnchorRecord: - replacement_spec: ReplacementSpec - passing_tensors: List[Tensor] - batch_tensor_indexs: List[int] - sub_cutting_points: List[Tensor] - sub_graph_def: GraphDef - input_names: List[str] - output_names: List[str] - is_training: bool - input_indexs: List[int] = None - - -@dataclasses.dataclass -class SubgraphInfo: - subgraph_in: Dict[Operation, Set[Operation]] - subgraph_out: Dict[Operation, Set[Operation]] - subgraph_to_push: Set[Operation] diff --git a/mx_rec/graph/hooks.py b/mx_rec/graph/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..c97ae299ea821550934671c5b10a5d8a218a89ff --- /dev/null +++ b/mx_rec/graph/hooks.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
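The anchor enums in `constants.py` are matched against `Operation.type` when scanning a graph for iterator anchors. A hedged sketch of that pattern over a toy TF1-style pipeline (dataset contents are illustrative; assumes mxRec is importable):

```python
import tensorflow as tf

from mx_rec.graph.constants import AnchorIteratorOp

tf.compat.v1.disable_eager_execution()

with tf.Graph().as_default() as graph:
    dataset = tf.compat.v1.data.Dataset.from_tensor_slices([1, 2, 3]).prefetch(1)
    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    _ = iterator.get_next()

    # Anchor lookup by op type string, as the modifier and slicers do.
    get_next_ops = [
        op for op in graph.get_operations()
        if op.type == AnchorIteratorOp.ITERATOR_GET_NEXT.value
    ]
    make_iter_ops = [
        op for op in graph.get_operations()
        if op.type == AnchorIteratorOp.MAKE_ITERATOR.value
    ]
    print(get_next_ops, make_iter_ops)
```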
+# ==============================================================================
+
+from typing import List
+
+import tensorflow as tf
+from tensorflow import Operation, Graph
+
+from mx_rec.util.log import logger
+from mx_rec.graph.slicers import LookupSubgraphSlicer, OrphanLookupKeySlicer
+from mx_rec.validator.validator import ClassValidator, para_checker_decorator
+
+
+class LookupSubgraphSlicerHook(tf.estimator.SessionRunHook):
+    @para_checker_decorator(
+        check_option_list=[
+            ("op_types", ClassValidator, {"classes": (list,)}),
+        ]
+    )
+    def __init__(self, op_types: List[Operation]) -> None:
+        super().__init__()
+        self._op_types = op_types
+
+    def begin(self) -> None:
+        slicer = LookupSubgraphSlicer(self._op_types)
+
+        logger.info("Starts to summarize sliceable specific operations in lookup subgraph!")
+        slicer.summarize()
+
+        logger.info("Starts to slice specific operations and their corresponding minimum dependency graphs!")
+        slicer.slice()
+
+
+class OrphanLookupKeySlicerHook(tf.estimator.SessionRunHook):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def begin(self) -> None:
+        slicer = OrphanLookupKeySlicer()
+
+        logger.info("Starts to summarize sliceable orphan lookup keys!")
+        slicer.summarize()
+
+        logger.info("Starts to slice orphan lookup keys and their corresponding minimum dependency graphs!")
+        slicer.slice()
diff --git a/mx_rec/graph/merge_lookup.py b/mx_rec/graph/merge_lookup.py
index 8a11e515b4523bcffbdcbda853247f9fa76ef7e8..0b646cab94b085e6228e28b0f1d5a6059170a407 100644
--- a/mx_rec/graph/merge_lookup.py
+++ b/mx_rec/graph/merge_lookup.py
@@ -50,7 +50,7 @@ def do_merge_lookup(is_train: bool = True):
     # get anchor ids
     cutting_point_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE)
     if not cutting_point_list:
-        raise RuntimeError("The sparse table does not have sparse lookup.")
+        raise RuntimeError("the sparse table does not have sparse lookup.")
     check_cutting_points(cutting_point_list)
     # get lookup info
@@ -91,7 +91,8 @@ def do_merge_lookup(is_train: bool = True):
         if not ConfigInitializer.get_instance().use_static:
             kwargs["feature_spec_name_ids_dict"] = feature_spec_name_ids_dict
         lookup_result = table_instance.lookup_for_feat_spec(feature_spec, send_count, **kwargs)
-        replace_anchor_vec(cutting_point, ASCAnchorAttr.MOCK_LOOKUP_RESULT, lookup_result)
+        graph = tf.compat.v1.get_default_graph()
+        replace_anchor_vec(graph, cutting_point, ASCAnchorAttr.MOCK_LOOKUP_RESULT, lookup_result)
         logger.debug("The mock lookup result of %s for %s was replaced.", feature_spec.name,
                      table_instance.table_name)
     # records whether the current mode has been merged or restored lookup
diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py
index a5843e021d09ea4523d9776c02503bbbff666989..972054819de2ae4adb7738dcf8d01d0602e57c6d 100644
--- a/mx_rec/graph/modifier.py
+++ b/mx_rec/graph/modifier.py
@@ -15,84 +15,529 @@
 # limitations under the License.
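Both slicer hooks above follow the standard `SessionRunHook` shape: every graph query or rewrite happens in `begin()`, before `MonitoredSession` finalizes the graph. An illustrative stand-in that only reports matching ops (the real slicers live in `mx_rec.graph.slicers`, not shown here):

```python
import tensorflow as tf


class LoggingSlicerHook(tf.estimator.SessionRunHook):
    """Illustrative hook; it mimics the slicer hooks' structure only."""

    def __init__(self, op_type: str) -> None:
        super().__init__()
        self._op_type = op_type

    def begin(self) -> None:
        # The graph is still mutable here, so a real slicer can rewrite it.
        graph = tf.compat.v1.get_default_graph()
        matches = [op.name for op in graph.get_operations() if op.type == self._op_type]
        # A real slicer would now cut these ops and their minimum dependency
        # graphs out of the lookup subgraph; this sketch only reports them.
        tf.compat.v1.logging.info("sliceable ops of type %s: %s", self._op_type, matches)


# Passed like any other hook, e.g. estimator.train(..., hooks=[LoggingSlicerHook("Unique")]).
```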
# ============================================================================== +import dataclasses from collections import defaultdict from collections.abc import Callable -from typing import Any, List, Dict, Tuple +from typing import Any, List, Dict, Tuple, DefaultDict import tensorflow as tf -from tensorflow import Operation, Tensor +from tensorflow import Operation, Tensor, Graph from tensorflow.core.framework.graph_pb2 import GraphDef from tensorflow.python.data.ops.dataset_ops import DatasetV1Adapter from tensorflow.python.framework.errors_impl import InvalidArgumentError -from mx_rec.constants.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \ - ASCAnchorAttr, ASCEND_TIMESTAMP, MAX_WHILE_SIZE, LIBREC_EOS_OPS_SO, AnchorDatasetOp, \ - AnchorIteratorOp +from mx_rec.graph import utils +from mx_rec.constants.constants import ( + ASCEND_CUTTING_POINT_INITIALIZER, + ASCEND_SPARSE_LOOKUP_ENTRANCE, + ASCAnchorAttr, + ASCEND_TIMESTAMP, + MAX_WHILE_SIZE, + LIBREC_EOS_OPS_SO, +) from mx_rec.core.asc.feature_spec import FeatureSpec from mx_rec.core.asc.helper import get_asc_insert_func from mx_rec.core.asc.manager import start_asc_pipeline +from mx_rec.core.asc.swap_args import SwapArgs, SwapDataType +from mx_rec.core.asc.build_graph import SwapInfo from mx_rec.core.emb.base_sparse_embedding import BaseSparseEmbedding from mx_rec.graph.merge_lookup import do_merge_lookup -from mx_rec.graph.utils import check_input_list, find_parent_op, check_cutting_points, record_ops_to_replace, \ - export_pb_graph, make_sorted_key_to_tensor_list -from mx_rec.graph.graph_typing import AnchorRecord, ReplacementSpec +from mx_rec.graph.utils import check_and_force_list, export_pb_graph +from mx_rec.graph.constants import DeprecatedOp, AnchorDatasetOp, AnchorIteratorOp from mx_rec.util.initialize import ConfigInitializer from mx_rec.util.log import logger from mx_rec.util.ops import import_host_pipeline_ops from mx_rec.util.perf import performance +from mx_rec.util.tf_version_adapter import npu_ops from mx_rec.validator.validator import para_checker_decorator, ClassValidator -def get_preprocessing_map_func( +class GraphModifierHook(tf.estimator.SessionRunHook): + @para_checker_decorator( + check_option_list=[ + ("dump_graph", ClassValidator, {"classes": (bool,)}), + ("modify_graph", ClassValidator, {"classes": (bool,)}), + ] + ) + def __init__(self, dump_graph: bool = False, modify_graph: bool = True): + self._dump_graph = dump_graph + self._modify_graph = modify_graph + self._iterator_type = None + + ConfigInitializer.get_instance().train_params_config.is_graph_modify_hook_running = True + + def begin(self): + if self._modify_graph: + modify_graph_and_start_emb_cache(dump_graph=self._dump_graph) + else: + start_asc_pipeline() + + self._iterator_type = ConfigInitializer.get_instance().train_params_config.iterator_type + if self._modify_graph and self._iterator_type not in ( + AnchorIteratorOp.MAKE_ITERATOR.value, + AnchorIteratorOp.ONE_SHOT_ITERATOR.value, + ): + raise ValueError("the value of iterator type should be like `MakeIterator` or `OneShotIterator`.") + logger.debug("In GraphModifierHook, iterator type is `%s`.", self._iterator_type) + + def after_create_session(self, session, coord): + if self._modify_graph and self._iterator_type == AnchorIteratorOp.MAKE_ITERATOR.value: + session.run(tf.compat.v1.get_collection(ASCEND_CUTTING_POINT_INITIALIZER)) + + +@dataclasses.dataclass +class _AnchorRecord: + replacement_spec: DefaultDict[Tensor, List[Tuple[int, Operation]]] + 
passing_tensors: List[Tensor]
+    batch_tensor_indexs: List[int]
+    sub_cutting_points: List[Tensor]
+    sub_graph_def: GraphDef
+    input_names: List[str]
+    output_names: List[str]
+    is_training: bool
+    input_indexs: List[int] = None
+
+
+class _GraphModifier:
+    @para_checker_decorator(
+        check_option_list=[
+            ("full_graph", ClassValidator, {"classes": (Graph, type(None))}),
+            ("dump_graph", ClassValidator, {"classes": (bool,)}),
+        ]
+    )
+    def __init__(self, full_graph: Graph = None, dump_graph: bool = False):
+        if not full_graph:
+            full_graph = tf.compat.v1.get_default_graph()
+        self._full_graph = full_graph
+        self._dump_graph = dump_graph
+
+    @staticmethod
+    def _get_preprocessing_map_func(
         graph_def: GraphDef,
         input_names: List[str],
         output_names: List[str],
         batch_tensor_names: List[str] = None,
-        pipeline_input_indexes: List[int] = None
-) -> Callable:
-    input_names = check_input_list(input_names, str)
-    output_names = check_input_list(output_names, str)
-    batch_tensor_names = check_input_list(batch_tensor_names, str)
-    pipeline_input_indexes = check_input_list(pipeline_input_indexes, int)
-    both_is_none = batch_tensor_names is None and pipeline_input_indexes is None
-    both_not_none = batch_tensor_names is not None and pipeline_input_indexes is not None
-    if both_is_none or both_not_none:
-        raise ValueError("It is legal when and only when one of the parameters 'batch_tensor_names' and "
-                         "'pipeline_input_indexes' was given.")
-
-    def map_func(*args):
-        logger.debug("In get_preprocessing_map_func, the old batch is: %s.", args)
-        batch = dict()
-        parse_batch(args, batch, key=None)
-        logger.debug("In get_preprocessing_map_func, the parse batch is: %s.", batch)
-
-        input_tensors = []
-        if batch_tensor_names is not None:
-            for tensor_name in batch_tensor_names:
-                tensor = batch.get(tensor_name)
-                if tensor is None:
-                    raise ValueError(f"Given input_tensor_name '{tensor_name}' is invalid.")
-
-                input_tensors.append(tensor)
+        pipeline_input_indexes: List[int] = None,
+    ) -> Callable:
+        input_names = check_and_force_list(input_names, str)
+        output_names = check_and_force_list(output_names, str)
+        batch_tensor_names = check_and_force_list(batch_tensor_names, str)
+        pipeline_input_indexes = check_and_force_list(pipeline_input_indexes, int)
+        both_is_none = batch_tensor_names is None and pipeline_input_indexes is None
+        both_not_none = batch_tensor_names is not None and pipeline_input_indexes is not None
+        if both_is_none or both_not_none:
+            raise ValueError(
+                "Exactly one of the parameters 'batch_tensor_names' and "
+                "'pipeline_input_indexes' must be given."
+ ) + + def map_func(*args): + logger.debug("In get_preprocessing_map_func, the old batch is: %s.", args) + batch = dict() + _parse_batch(args, batch, key=None) + logger.debug("In get_preprocessing_map_func, the parse batch is: %s.", batch) + + input_tensors = [] + if batch_tensor_names is not None: + for tensor_name in batch_tensor_names: + tensor = batch.get(tensor_name) + if tensor is None: + raise ValueError(f"Given input_tensor_name '{tensor_name}' is invalid.") + + input_tensors.append(tensor) + + else: + graph = tf.compat.v1.get_default_graph() + for index in pipeline_input_indexes: + tensor = graph.get_tensor_by_name("args_%d:0" % index) + input_tensors.append(tensor) + + # 以tf.import_graph_def()作为read emb key的输入,保证数据读取到传入lookup的ids过程中的特征处理关系能够保留在子图中。 + output_list = tf.import_graph_def( + graph_def, input_map=dict(zip(input_names, input_tensors)), return_elements=output_names + ) + + output_batch = [batch, tuple(output_list)] + logger.debug("In get_preprocessing_map_func, the output batch is: %s.", output_batch) + return tuple(output_batch) + + return map_func + + @performance("graph_modifier") + def modify_graph_for_asc(self, prefetch: int = 10): + cutting_point_list = self._full_graph.get_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE) + utils.check_cutting_points(cutting_point_list) + if not cutting_point_list: + logger.warning("Nothing to revise.") + return + + export_pb_graph("old_graph.pbtxt", self._dump_graph, graph_def=self._full_graph.as_graph_def()) + get_next_op_map = self._generate_get_next_op_specs(cutting_point_list) + logger.debug( + "In modify_graph_for_asc function, get_next_op_map.len: %d, get_next_op_map.key: %s.", + len(get_next_op_map), + get_next_op_map.keys(), + ) + + for get_next_op, record in get_next_op_map.items(): + is_training = record.is_training + + # get source dataset + src_dataset = self._get_src_dataset(get_next_op, is_training) + + # generate target dataset + timestamp_index = _get_timestamp_index(self._full_graph, get_next_op, is_training) + original_batch_tensor_count = _get_dataset_tensor_count(src_dataset) + sub_cutting_points = record.sub_cutting_points + input_index_list = _get_input_index_list( + sub_cutting_points, + record.replacement_spec, + record.output_names, + original_batch_tensor_count, + timestamp_index=timestamp_index, + ) + record.input_indexs = input_index_list + + with self._full_graph.as_default(): + tgt_dataset = self._get_tgt_dataset(src_dataset, sub_cutting_points, record, prefetch=prefetch) + self._update_iterator_getnext(get_next_op, tgt_dataset, is_training, record) + + # In eval mode, backward is not required. In addition, compute gradients is not executed when + # only eval is used. Therefore, `do_merge_lookup` needs to be invoked during modify graph. + if not is_training: + with self._full_graph.as_default(): + do_merge_lookup(is_train=False) + if "evaluate" in ConfigInitializer.get_instance().train_params_config.bool_gauge_set: + logger.debug("In estimator mode, eval re-creates graph each time, so the flag needs to be cleared.") + ConfigInitializer.get_instance().train_params_config.insert_merged_multi_lookup(is_training, False) + # In training mode, `do_merge_lookup` should have been executed in compute gradients phase. + if is_training and not ConfigInitializer.get_instance().train_params_config.get_merged_multi_lookup(True): + raise RuntimeError( + "In training mode, `do_merge_lookup` should have been executed in compute gradients " + "phase. Please check whether compute gradients is performed." 
+ ) + + self._modify_graph_for_ddr(get_next_op_map) + + logger.info("Graph has been revised.") + export_pb_graph("new_graph.pbtxt", self._dump_graph, graph_def=self._full_graph.as_graph_def()) + + def _modify_graph_for_ddr(self, get_next_op_map: Dict[Tensor, _AnchorRecord]): + # 通过create_hash_optimizer创建optimizer_instance + optimizer_instance = ConfigInitializer.get_instance().optimizer_config.optimizer_instance + # Predict mode + if optimizer_instance is None: + slot_num = 0 + else: + # DDR和扩容需要在获取优化器后重置ext + _change_ext_emb_size_by_opt(optimizer_instance) + slot_num = optimizer_instance.slot_num + + for _, record in get_next_op_map.items(): + is_training = record.is_training + channel_id = 0 if is_training else 1 + swap_args = SwapArgs() + sparse_variables = self._full_graph.get_collection( + ConfigInitializer.get_instance().train_params_config.ascend_global_hashtable_collection + ) + + for each_var in sparse_variables: + table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance(each_var) + if table_instance.is_hbm: + continue + variable_and_slot_list = _get_variable_and_slot_list( + each_var, slot_num, table_instance.table_name, channel_id + ) + + swap_args_dict = swap_args.swap_config_dict[table_instance.table_name][channel_id] + swap_op = _get_swap_info( + table_instance, variable_and_slot_list, swap_args_dict["swap_info"], channel_id) + # gather for id_offset need to be executed after swap_op + swap_control_dict = swap_args.swap_control_dict[table_instance.table_name][channel_id] + if SwapDataType.CONTROL_OPS.value not in swap_control_dict: + raise ValueError("swap control missing key [control_ops] in modify_graph_for_asc") + control_ops = swap_control_dict[SwapDataType.CONTROL_OPS.value] + utils.replace_anchor_control(self._full_graph, control_ops, swap_op) + + if is_training and slot_num > 1: + # gather for slot need to be executed after swap_op + slot_control_dict = swap_args.slot_control_dict[table_instance.variable] + if SwapDataType.CONTROL_OPS.value not in slot_control_dict: + raise ValueError("slot control missing key [control_ops] in modify_graph_for_asc") + slot_control_ops = slot_control_dict[SwapDataType.CONTROL_OPS.value] + utils.replace_anchor_control(self._full_graph, slot_control_ops, swap_op) + + def _generate_get_next_op_specs(self, cutting_point_list: List[Tensor]) -> Dict[Tensor, _AnchorRecord]: + get_next_op_map = defaultdict(dict) + + for input_tensor in cutting_point_list: + get_next_op = utils.upward_bfs_op(input_tensor.op, AnchorIteratorOp.ITERATOR_GET_NEXT.value) + if get_next_op not in get_next_op_map: + logger.debug("find a new get_next_op named '%s'", get_next_op.name) + + replacement_specs = utils.record_ops_to_replace(self._full_graph, get_next_op) + passing_tensors, batch_tensor_indexs, sub_cutting_points = _get_passing_tensor_list( + cutting_point_list, get_next_op + ) + sub_graph_def, input_names, output_names = self._get_sub_graph(passing_tensors, sub_cutting_points) + is_training = BaseSparseEmbedding.get_anchor_attribute(input_tensor, ASCAnchorAttr.IS_TRAINING) + + record = _AnchorRecord( + replacement_specs, + passing_tensors, + batch_tensor_indexs, + sub_cutting_points, + sub_graph_def, + input_names, + output_names, + is_training, + ) + get_next_op_map[get_next_op] = record + + export_pb_graph(f"cut_graph_{get_next_op.name}.pbtxt", self._dump_graph, graph_def=sub_graph_def) + + return get_next_op_map + + def _get_sub_graph( + self, input_tensors: List[Tensor], output_tensors: List[Tensor] + ) -> Tuple[GraphDef, 
List[str], List[str]]: + input_tensors = check_and_force_list(input_tensors, tf.Tensor) + output_tensors = check_and_force_list(output_tensors, tf.Tensor) + input_op_name_list = [tensor.op.name for tensor in input_tensors] + output_op_name_list = [tensor.op.name for tensor in output_tensors] + + graph_def = self._full_graph.as_graph_def() + cut_graph_input = tf.compat.v1.graph_util.extract_sub_graph(graph_def, input_op_name_list) + cut_graph_output = tf.compat.v1.graph_util.extract_sub_graph(graph_def, output_op_name_list) + + node_list = [] + node_list_input = cut_graph_input.node + node_list_output = cut_graph_output.node + for node in node_list_output: + if node not in node_list_input: + node_list.append(node) + + sub_graph_def = tf.compat.v1.GraphDef() + sub_graph_def.node.extend(node_list) + + input_name_list = [tensor.name for tensor in input_tensors] + output_name_list = [tensor.name for tensor in output_tensors] + + return sub_graph_def, input_name_list, output_name_list + + def _get_src_dataset(self, get_next_op: Operation, is_training: bool) -> DatasetV1Adapter: + """ + 根据`IteratorGetNext`算子在计算图中找出原始dataset. + + Args: + get_next_op: `IteratorGetNext`算子 + is_training: 当前是否为训练模式,训练模式为True,否则为False + + Returns: 原始数据集 + + """ + + try: + target_op = utils.find_trans_dataset(self._full_graph, get_next_op) + except (ValueError, TypeError, RuntimeError) as err: + logger.warning("The dataset op was not found, the error is `%s`. Start to traverse the operations.", err) + graph = self._full_graph + dataset_op_list = [op for op in graph.get_operations() if AnchorDatasetOp.PREFETCH_DATASET.value in op.name] + + # WARN: Couple with NoGradSubGraphSlicer::_find_old_dataset. + dataset_op_list = list( + filter( + lambda op: op not in self._full_graph.get_collection(DeprecatedOp.DEPRECATED_PREFETCH_DATASET), + dataset_op_list, + ) + ) + dataset_op_list = sorted(dataset_op_list, key=lambda op: op.name) + + logger.debug( + "In get_src_dataset function, current mode(train: True, eval: False): %s, dataset_op_list: %s.", + is_training, + dataset_op_list, + ) + + if len(dataset_op_list) == 1: + target_op = dataset_op_list[0] + elif is_training and len(dataset_op_list) == 2: + prefetch_dataset_op_list = sorted(dataset_op_list, key=lambda op: op.name) + target_op = prefetch_dataset_op_list[0] + elif not is_training and len(dataset_op_list) == 3: + prefetch_dataset_op_list = sorted(dataset_op_list, key=lambda op: op.name) + target_op = prefetch_dataset_op_list[1] + else: + raise RuntimeError( + f"'{AnchorDatasetOp.PREFETCH_DATASET.value}' not found, got transformation datasets: " + f"{dataset_op_list}." + ) from err + except Exception as err: + raise RuntimeError(f"The dataset was not found, the error is `{err}`.") from err + + if not target_op.outputs: + raise ValueError(f"The length of the outputs of target op `{target_op}` is 0.") + logger.debug("Find target op `%s`, and output is `%s`.", target_op.name, target_op.outputs) + src_dataset = utils.find_target_instance_dataset(self._full_graph, target_op.outputs[0]) + return src_dataset + + def _get_tgt_dataset( + self, + src_dataset: DatasetV1Adapter, + sub_cutting_point_list: List[Tensor], + record: _AnchorRecord, + prefetch: int = 10, + ) -> DatasetV1Adapter: + """ + 根据原始数据集生成新的数据集实例. 
+
+        Args:
+            src_dataset: 原始数据集实例
+            sub_cutting_point_list: 打桩的lookup ids列表
+            record: 记录被打桩ids对应输入/输出算子、子图关系等信息的`_AnchorRecord`实例
+            prefetch: dataset预取数据量,默认为10
+
+        Returns: 新数据集实例
+
+        """
+
+        librec = import_host_pipeline_ops(LIBREC_EOS_OPS_SO)
+        channel_id = ConfigInitializer.get_instance().train_params_config.get_training_mode_channel_id(
+            record.is_training
+        )
+        # 在数据读取完时,通过EosDataset向acl数据通道发送end_of_sequence
+        max_train_steps = ConfigInitializer.get_instance().max_steps
+        max_eval_steps = ConfigInitializer.get_instance().eval_steps
+        src_dataset = src_dataset.eos_map(librec, channel_id, max_train_steps, max_eval_steps)
+
+        tgt_dataset = src_dataset.map(
+            self._get_preprocessing_map_func(
+                record.sub_graph_def,
+                record.input_names,
+                record.output_names,
+                pipeline_input_indexes=record.batch_tensor_indexs,
+            )
+        )
+
+        feature_numbers = [
+            BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC).feat_cnt
+            for cutting_point in sub_cutting_point_list
+        ]
+        table_names = [
+            BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC).table_name
+            for cutting_point in sub_cutting_point_list
+        ]
+        tgt_dataset = tgt_dataset.map(
+            get_asc_insert_func(
+                feature_numbers=feature_numbers,
+                table_names=table_names,
+                args_index_list=record.input_indexs,
+                is_training=record.is_training,
+                dump_graph=self._dump_graph,
+            )
+        )
+
+        tgt_dataset = tgt_dataset.prefetch(prefetch)
+        return tgt_dataset
+
+    def _update_iterator_getnext(
+        self, get_next_op: Operation, tgt_dataset: DatasetV1Adapter, is_training: bool, record: _AnchorRecord
+    ) -> None:
+        """
+        用新数据集中的`IteratorGetNext`算子替换计算图中原始数据集的`IteratorGetNext`算子,即用新数据集的batch替换原始数据集的batch.
+
+        Args:
+            get_next_op: `IteratorGetNext`算子
+            tgt_dataset: 新数据集
+            is_training: 当前是否为训练模式,训练模式为True,否则为False
+            record: 记录被打桩ids对应输入/输出算子、子图关系等信息的`_AnchorRecord`实例
+
+        Returns: None
+
+        """
+        if not get_next_op.outputs:
+            raise RuntimeError("there is no tensor in the dataset. Please check the dataset and data processing.")
+        iterator_type = ""
+        if get_next_op.outputs[0].op.inputs:
+            iterator_type = get_next_op.outputs[0].op.inputs[0].op.type
+        if iterator_type == "IteratorV2":
+            iterator_type = utils.find_make_iterator_op(self._full_graph, get_next_op.outputs[0]).type
+        if iterator_type not in (AnchorIteratorOp.MAKE_ITERATOR.value, AnchorIteratorOp.ONE_SHOT_ITERATOR.value):
+            raise RuntimeError(
+                f"Only iterators `MakeIterator` and `OneShotIterator` are supported in `graph modify` mode, "
+                f"but the current iterator is `{iterator_type}`."
+ ) + ConfigInitializer.get_instance().train_params_config.iterator_type = iterator_type + logger.info("The iterator type of dataset is `%s`.", iterator_type) + + if iterator_type == AnchorIteratorOp.MAKE_ITERATOR.value: + new_iterator = tgt_dataset.make_initializable_iterator() + tf.compat.v1.add_to_collection(ASCEND_CUTTING_POINT_INITIALIZER, new_iterator.initializer) + ConfigInitializer.get_instance().train_params_config.set_initializer(is_training, new_iterator.initializer) else: - graph = tf.compat.v1.get_default_graph() - for index in pipeline_input_indexes: - tensor = graph.get_tensor_by_name("args_%d:0" % index) - input_tensors.append(tensor) + new_iterator = tgt_dataset.make_one_shot_iterator() + new_batch = new_iterator.get_next() + ConfigInitializer.get_instance().train_params_config.set_target_batch(is_training, new_batch) + + try: + new_batch_tensor = list(new_batch.values())[0] + except IndexError as err: + raise IndexError("Cannot find a tensor from given batch.") from err + new_get_next_op_name = utils.upward_bfs_op(new_batch_tensor.op, AnchorIteratorOp.ITERATOR_GET_NEXT.value).name + self._update_input_tensor_with_new_batch(record.replacement_spec, new_get_next_op_name, new_batch) + + def _update_input_tensor_with_new_batch( + self, + replacement_specs: DefaultDict[Tensor, List[Tuple[int, Operation]]], + new_get_next_op_name: str, + new_batch: Dict[str, Tensor], + ) -> None: + """ + 用新batch中的IteratorGetNext替换计算图中老batch的IteratorGetNext. + + Args: + replacement_specs: 记录待替换算子的dict,key为老batch的IteratorGetNext,value为以老batch作为输入的算子 + new_get_next_op_name: 新数据集的get_next算子名称 + new_batch: 新数据集的batch + + Returns: None + + """ - # 以tf.import_graph_def()作为read emb key的输入,保证数据读取到传入lookup的ids过程中的特征处理关系能够保留在子图中。 - output_list = tf.import_graph_def(graph_def, input_map=dict(zip(input_names, input_tensors)), - return_elements=output_names) + for old_tensor, item in replacement_specs.items(): + for idx, operator in item: + old_tensor_name = old_tensor.name + output_index = old_tensor_name.split(":")[-1] + new_tensor_name = f"{new_get_next_op_name}:{output_index}" + new_tensor = self._full_graph.get_tensor_by_name(new_tensor_name) + try: + operator._update_input(idx, new_tensor) + except InvalidArgumentError as err: + logger.info( + "The replacement specs keys (old batch) is: %s. \n\t\t The new batch is: %s.", + replacement_specs.keys(), + new_batch, + ) + raise RuntimeError( + f"Cannot update edge, old tensor: {old_tensor}, new tensor: {new_tensor}." + ) from err - output_batch = [batch, tuple(output_list)] - logger.debug("In get_preprocessing_map_func, the output batch is: %s.", output_batch) - return tuple(output_batch) - return map_func +@para_checker_decorator( + check_option_list=[ + ("full_graph", ClassValidator, {"classes": (Graph, type(None))}), + ("dump_graph", ClassValidator, {"classes": (bool,)}), + ] +) +def modify_graph_and_start_emb_cache(full_graph: Graph = None, dump_graph: bool = False): + modifier = _GraphModifier(full_graph=full_graph, dump_graph=dump_graph) + modifier.modify_graph_for_asc() + start_asc_pipeline() -def parse_batch(data_args: Any, data_batch: dict, key: str = None): +def _parse_batch(data_args: Any, data_batch: dict, key: str = None): """ 解析原始数据集中的batch,并将非dict格式的batch转为dict格式. 
Args: @@ -116,7 +561,7 @@ def parse_batch(data_args: Any, data_batch: dict, key: str = None): """ - if key is not None: + if key: data_batch[key] = data_tensor return @@ -126,11 +571,11 @@ def parse_batch(data_args: Any, data_batch: dict, key: str = None): # 开始解析old batch if isinstance(data_args, dict): for key, data_tensor in data_args.items(): - parse_batch(data_tensor, data_batch, key) + _parse_batch(data_tensor, data_batch, key) return if isinstance(data_args, (list, tuple)): for data_arg in data_args: - parse_batch(data_arg, data_batch, key) + _parse_batch(data_arg, data_batch, key) return if isinstance(data_args, Tensor): # 将old batch中的tensor加入到dict中 @@ -140,12 +585,12 @@ def parse_batch(data_args: Any, data_batch: dict, key: str = None): raise ValueError(f"Invalid batch type, expected: (dict, list, tuple, Tensor), got: {type(data_args)}.") -def get_input_index_list( - cutting_point_list: List[Tensor], - replacement_specs: ReplacementSpec, - mapping_name_list: List[str], - base_count: int, - timestamp_index: int = None +def _get_input_index_list( + cutting_point_list: List[Tensor], + replacement_specs: DefaultDict[Tensor, List[Tuple[int, Operation]]], + mapping_name_list: List[str], + base_count: int, + timestamp_index: int = None, ) -> List[int]: input_index_list = [] for cutting_point in cutting_point_list: @@ -164,78 +609,8 @@ def get_input_index_list( return input_index_list -def find_make_iterator_op(batch_tensor: Tensor) -> Operation: - graph = tf.compat.v1.get_default_graph() - operations = graph.get_operations() - for each_op in operations: - for input_tensor in batch_tensor.op.inputs: - if input_tensor.op.outputs and input_tensor.op.outputs[0] in list( - each_op.inputs) and each_op.type == AnchorIteratorOp.MAKE_ITERATOR.value: - logger.debug("Op MakeIterator '%s' was found.", each_op.name) - return each_op - - raise ValueError(f"Op MakeIterator was not found.") - - -@performance("find_target_dataset_op") -def find_target_dataset_op(base_ops: Operation, op_type: str) -> Operation: - base_ops = check_input_list(base_ops, tf.Operation) - parent_ops = base_ops - - while_num = 0 - while True: - while_num += 1 - if while_num > MAX_WHILE_SIZE: - raise RuntimeError(f"In find_target_dataset_op function, the maximum cycle depth is greater " - f"than {MAX_WHILE_SIZE}.") - for parent_op in parent_ops: - if parent_op.type == op_type: - return parent_op - - base_ops = parent_ops - parent_ops = [] - for base_op in base_ops: - parent_ops.extend(find_parent_op(base_op)) - - if not parent_ops: - raise ValueError(f"Op {op_type} was not found.") - - -def get_dataset_op(get_next_op: Operation) -> Operation: - """ - 根据`IteratorGetNext`算子从图中找到`OptimizeDataset`的dataset op. - 注: TF2没有`OptimizeDataset`,则找的是dataset的默认锚点. 
- - Args: - get_next_op: `IteratorGetNext`算子 - - Returns: TF1返回`OptimizeDataset`算子,TF2返回dataset默认锚点的算子 - - """ - - if get_next_op.type != AnchorIteratorOp.ITERATOR_GET_NEXT.value: - raise TypeError("Op '{get_next_op}' must be one instance of IteratorGetNext.") - - # looking for the MakeIterator operator which corresponds to given batch_tensor - base_op = find_make_iterator_op(get_next_op.outputs[0]) - # looking for the op which is the one before OptimizeDataset operator - if tf.__version__.startswith("1"): - optimize_dataset_op = find_target_dataset_op(base_op, AnchorDatasetOp.MODEL_DATASET.value) - target_op = find_parent_op(optimize_dataset_op) - if not target_op: - raise RuntimeError(f"The parent op for 'ModelDataset' op was not found.") - if target_op[0].type != AnchorDatasetOp.OPTIMIZE_DATASET.value: - raise TypeError(f"Op OptimizeDataset was not found.") - target_op = target_op[0] - else: - # 'OptimizeDataset' is not available in TensorFlow2.X - target_op = find_target_dataset_op(base_op, AnchorDatasetOp.PREFETCH_DATASET.value) - return target_op - - -def get_passing_tensor_list( - src_tensors: List[Tensor], - target_op: Operation +def _get_passing_tensor_list( + src_tensors: List[Tensor], target_op: Operation ) -> Tuple[List[Tensor], List[int], List[Tensor]]: def get_passing_tensors(src_tensor): passing_tensors = [] @@ -244,8 +619,9 @@ def get_passing_tensor_list( while tensor_list: while_num += 1 if while_num > MAX_WHILE_SIZE: - raise RuntimeError(f"In get_passing_tensors function, the maximum cycle depth is greater " - f"than {MAX_WHILE_SIZE}.") + raise RuntimeError( + f"In get_passing_tensors function, the maximum cycle depth is greater " f"than {MAX_WHILE_SIZE}." + ) last_tensor = tensor_list.pop() if last_tensor.op is target_op: passing_tensors.append(last_tensor) @@ -254,7 +630,7 @@ def get_passing_tensor_list( return passing_tensors - src_tensors = check_input_list(src_tensors, Tensor) + src_tensors = check_and_force_list(src_tensors, Tensor) passing_tensor_list = [] sub_src_tensors = [] for tensor in src_tensors: @@ -273,83 +649,7 @@ def get_passing_tensor_list( return passing_tensor_list, output_index_list, sub_src_tensors -def find_target_instance_dataset(variant_tensor: Tensor) -> DatasetV1Adapter: - dataset_instance_list = tf.compat.v1.get_collection("dataset_group") - for ins in dataset_instance_list: - if ins._variant_tensor == variant_tensor: - if not isinstance(ins, DatasetV1Adapter): - ins = ins._input_dataset - logger.debug("Find target instance '%s', whose variant_tensor is '%s'.", ins, variant_tensor) - if not isinstance(ins.element_spec, dict) and not ( - isinstance(ins.element_spec, (list, tuple)) and len(ins.element_spec) == 2 and isinstance( - ins.element_spec[0], dict)): - raise NotImplementedError("The found dataset does not return a valid layout.") - - return ins - - raise LookupError(f"Can not find target instance, whose variant_tensor is '{variant_tensor}' respectively.") - - -def get_sub_graph( - input_tensors: List[Tensor], - output_tensors: List[Tensor] -) -> Tuple[GraphDef, List[str], List[str]]: - input_tensors = check_input_list(input_tensors, tf.Tensor) - output_tensors = check_input_list(output_tensors, tf.Tensor) - input_op_name_list = [tensor.op.name for tensor in input_tensors] - output_op_name_list = [tensor.op.name for tensor in output_tensors] - - graph_def = tf.compat.v1.get_default_graph().as_graph_def() - cut_graph_input = tf.compat.v1.graph_util.extract_sub_graph(graph_def, input_op_name_list) - cut_graph_output = 
tf.compat.v1.graph_util.extract_sub_graph(graph_def, output_op_name_list) - - node_list = [] - node_list_input = cut_graph_input.node - node_list_output = cut_graph_output.node - for node in node_list_output: - if node not in node_list_input: - node_list.append(node) - - sub_graph_def = tf.compat.v1.GraphDef() - sub_graph_def.node.extend(node_list) - - input_name_list = [tensor.name for tensor in input_tensors] - output_name_list = [tensor.name for tensor in output_tensors] - - return sub_graph_def, input_name_list, output_name_list - - -def update_input_tensor_with_new_batch(replacement_specs: ReplacementSpec, - new_get_next_op_name: str, - new_batch: Dict[str, Tensor]): - """ - 用新batch中的IteratorGetNext替换计算图中老batch的IteratorGetNext. - - Args: - replacement_specs: 记录待替换算子的dict,key为老batch的IteratorGetNext,value为以老batch作为输入的算子 - new_get_next_op_name: 新数据集的get_next算子名称 - new_batch: 新数据集的batch - - Returns: None - - """ - - graph = tf.compat.v1.get_default_graph() - for old_tensor, item in replacement_specs.items(): - for idx, operator in item: - old_tensor_name = old_tensor.name - output_index = old_tensor_name.split(":")[-1] - new_tensor_name = f"{new_get_next_op_name}:{output_index}" - new_tensor = graph.get_tensor_by_name(new_tensor_name) - try: - operator._update_input(idx, new_tensor) - except InvalidArgumentError as err: - logger.info("The replacement specs keys (old batch) is: %s. \n\t\t The new batch is: %s.", - replacement_specs.keys(), new_batch) - raise RuntimeError(f"Cannot update edge, old tensor: {old_tensor}, new tensor: {new_tensor}.") from err - - -def get_dataset_tensor_count(dataset: DatasetV1Adapter) -> int: +def _get_dataset_tensor_count(dataset: DatasetV1Adapter) -> int: """ 获取数据集中batch的tensor数量. @@ -363,241 +663,13 @@ def get_dataset_tensor_count(dataset: DatasetV1Adapter) -> int: src_element_spec = dataset.element_spec if not isinstance(src_element_spec, (list, tuple)): src_element_spec = [src_element_spec] - src_sorted_keys = make_sorted_key_to_tensor_list(src_element_spec, []) + src_sorted_keys = utils.make_sorted_key_to_tensor_list(src_element_spec, []) return len(src_sorted_keys) -@para_checker_decorator( - check_option_list=[("dump_graph", ClassValidator, {"classes": (bool,)})] -) -def modify_graph_and_start_emb_cache(dump_graph: bool = False): - modify_graph_for_asc(dump_graph=dump_graph) - start_asc_pipeline() - - -def generate_get_next_op_specs( - cutting_point_list: List[Tensor], - dump_graph: bool = False -) -> Dict[Tensor, AnchorRecord]: - get_next_op_map = defaultdict(dict) - - for input_tensor in cutting_point_list: - get_next_op = find_target_dataset_op(input_tensor.op, AnchorIteratorOp.ITERATOR_GET_NEXT.value) - if get_next_op not in get_next_op_map: - logger.debug("find a new get_next_op named '%s'", get_next_op.name) - - replacement_specs = record_ops_to_replace(get_next_op) - passing_tensors, batch_tensor_indexs, sub_cutting_points = \ - get_passing_tensor_list(cutting_point_list, get_next_op) - sub_graph_def, input_names, output_names = get_sub_graph(passing_tensors, sub_cutting_points) - is_training = BaseSparseEmbedding.get_anchor_attribute(input_tensor, ASCAnchorAttr.IS_TRAINING) - - record = AnchorRecord( - replacement_specs, - passing_tensors, - batch_tensor_indexs, - sub_cutting_points, - sub_graph_def, - input_names, - output_names, - is_training - ) - get_next_op_map[get_next_op] = record - - export_pb_graph(f"cut_graph_{get_next_op.name}.pb", dump_graph, graph_def=sub_graph_def) - - return get_next_op_map - - -def get_src_dataset(get_next_op: 
Operation, is_training: bool) -> DatasetV1Adapter: - """ - 根据`IteratorGetNext`算子在计算图中找出原始dataset. - - Args: - get_next_op: `IteratorGetNext`算子 - is_training: 当前是否为训练模式,训练模式为True,否则为False - - Returns: 原始数据集 - - """ - - try: - target_op = get_dataset_op(get_next_op) - except (ValueError, TypeError, RuntimeError) as err: - logger.warning("The dataset op was not found, the error is `%s`. Start to traverse the operations.", err) - graph = tf.compat.v1.get_default_graph() - dataset_op_list = [op for op in graph.get_operations() if AnchorDatasetOp.PREFETCH_DATASET.value in op.name] - logger.debug("In get_src_dataset function, current mode(train: True, eval: False): %s, dataset_op_list: %s.", - is_training, dataset_op_list) - - if len(dataset_op_list) == 1: - target_op = dataset_op_list[0] - elif is_training and len(dataset_op_list) == 2: - prefetch_dataset_op_list = sorted(dataset_op_list, key=lambda op: op.name) - target_op = prefetch_dataset_op_list[0] - elif not is_training and len(dataset_op_list) == 3: - prefetch_dataset_op_list = sorted(dataset_op_list, key=lambda op: op.name) - target_op = prefetch_dataset_op_list[1] - else: - raise RuntimeError(f"`{AnchorDatasetOp.PREFETCH_DATASET.value}` not found, got dataset_op_list: " - f"{dataset_op_list}.") from err - except Exception as err: - raise RuntimeError(f"The dataset was not found, the error is `{err}`.") from err - - if not target_op.outputs: - raise ValueError(f"The length of the outputs of target op `{target_op}` is 0.") - logger.debug("Find target op `%s`, and output is `%s`.", target_op.name, target_op.outputs) - src_dataset = find_target_instance_dataset(target_op.outputs[0]) - return src_dataset - - -def get_tgt_dataset( - src_dataset: DatasetV1Adapter, - sub_cutting_point_list: List[Tensor], - record: AnchorRecord, - dump_graph: bool = False, - prefetch: int = 10 -) -> DatasetV1Adapter: - """ - 根据原始数据集生成新的数据集实例. 
- - Args: - src_dataset: 原始数据集实例 - sub_cutting_point_list: 打桩的lookup ids列表 - records: 记录被打桩ids对应输入/输出算子、子图关系等信息的字典 - dump_graph: 是否dump计算图,默认为False - prefetch: dataset预取数据量,默认为10 - - Returns: 新数据集实例 - - """ - - librec = import_host_pipeline_ops(LIBREC_EOS_OPS_SO) - channel_id = ConfigInitializer.get_instance().train_params_config.get_training_mode_channel_id( - record.is_training) - # 在数据读取完时,通过EosDataset向acl数据通道发送end_of_sequence - max_train_steps = ConfigInitializer.get_instance().max_steps - max_eval_steps = ConfigInitializer.get_instance().eval_steps - src_dataset = src_dataset.eos_map(librec, channel_id, max_train_steps, max_eval_steps) - - tgt_dataset = src_dataset.map(get_preprocessing_map_func(record.sub_graph_def, - record.input_names, - record.output_names, - pipeline_input_indexes=record.batch_tensor_indexs)) - - feature_numbers = [BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC).feat_cnt for - cutting_point in sub_cutting_point_list] - table_names = [BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC).table_name for - cutting_point in sub_cutting_point_list] - tgt_dataset = tgt_dataset.map(get_asc_insert_func(feature_numbers=feature_numbers, - table_names=table_names, - args_index_list=record.input_indexs, - is_training=record.is_training, - dump_graph=dump_graph)) - - tgt_dataset = tgt_dataset.prefetch(prefetch) - return tgt_dataset - - -def update_iterator_getnext(get_next_op: Operation, - tgt_dataset: DatasetV1Adapter, - is_training: bool, - record: AnchorRecord): - """ - 用新数据集中的`IteratorGetNext`算子替换计算图中原始数据集的`IteratorGetNext`算子,即用新数据集的batch替换原始数据集的batch. - - Args: - get_next_op: `IteratorGetNext`算子 - tgt_dataset: 新数据集 - is_training: 当前是否为训练模式,训练模式为True,否则为False - records: 记录被打桩ids对应输入/输出算子、子图关系等信息的字典 - - Returns: None - - """ - if not get_next_op.outputs: - raise RuntimeError("There is no tensor in the dataset. 
Please check the dataset and data processing.") - iterator_type = "" - if get_next_op.outputs[0].op.inputs: - iterator_type = get_next_op.outputs[0].op.inputs[0].op.type - if iterator_type == "IteratorV2": - iterator_type = find_make_iterator_op(get_next_op.outputs[0]).type - if iterator_type not in (AnchorIteratorOp.MAKE_ITERATOR.value, AnchorIteratorOp.ONE_SHOT_ITERATOR.value): - raise RuntimeError(f"Only iterators `MakeIterator` and `OneShotIterator` are supported in `graph modify` mode, " - f"but the current iterator is `{iterator_type}`.") - ConfigInitializer.get_instance().train_params_config.iterator_type = iterator_type - logger.info("The iterator type of dataset is `%s`.", iterator_type) - - if iterator_type == AnchorIteratorOp.MAKE_ITERATOR.value: - new_iterator = tgt_dataset.make_initializable_iterator() - tf.compat.v1.add_to_collection(ASCEND_CUTTING_POINT_INITIALIZER, new_iterator.initializer) - ConfigInitializer.get_instance().train_params_config.set_initializer(is_training, new_iterator.initializer) - else: - new_iterator = tgt_dataset.make_one_shot_iterator() - new_batch = new_iterator.get_next() - ConfigInitializer.get_instance().train_params_config.set_target_batch(is_training, new_batch) - - try: - new_batch_tensor = list(new_batch.values())[0] - except IndexError as err: - raise IndexError("Cannot find a tensor from given batch.") from err - new_get_next_op_name = find_target_dataset_op(new_batch_tensor.op, AnchorIteratorOp.ITERATOR_GET_NEXT.value).name - update_input_tensor_with_new_batch(record.replacement_spec, new_get_next_op_name, new_batch) - - -@performance("graph_modifier") -def modify_graph_for_asc(dump_graph: bool = False, prefetch: int = 10): - cutting_point_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE) - check_cutting_points(cutting_point_list) - if not cutting_point_list: - logger.warning("Nothing to revise.") - return - - export_pb_graph("old_graph.pb", dump_graph) - get_next_op_map = generate_get_next_op_specs(cutting_point_list, dump_graph) - logger.debug("In modify_graph_for_asc function, get_next_op_map.len: %d, get_next_op_map.key: %s.", - len(get_next_op_map), get_next_op_map.keys()) - - for get_next_op, record in get_next_op_map.items(): - is_training = record.is_training - - # get source dataset - src_dataset = get_src_dataset(get_next_op, is_training) - - # generate target dataset - timestamp_index = get_timestamp_index(get_next_op, is_training) - original_batch_tensor_count = get_dataset_tensor_count(src_dataset) - sub_cutting_points = record.sub_cutting_points - input_index_list = get_input_index_list(sub_cutting_points, - record.replacement_spec, - record.output_names, - original_batch_tensor_count, timestamp_index=timestamp_index) - record.input_indexs = input_index_list - tgt_dataset = get_tgt_dataset(src_dataset, sub_cutting_points, record, - dump_graph=dump_graph, prefetch=prefetch) - - # update the batch of dataset - update_iterator_getnext(get_next_op, tgt_dataset, is_training, record) - - # In eval mode, backward is not required. In addition, compute gradients is not executed when - # only eval is used. Therefore, `do_merge_lookup` needs to be invoked during modify graph. 
- if not is_training: - do_merge_lookup(is_train=False) - if 'evaluate' in ConfigInitializer.get_instance().train_params_config.bool_gauge_set: - logger.debug("In estimator mode, eval re-creates graph each time, so the flag needs to be cleared.") - ConfigInitializer.get_instance().train_params_config.insert_merged_multi_lookup(is_training, False) - # In training mode, `do_merge_lookup` should have been executed in compute gradients phase. - if is_training and not ConfigInitializer.get_instance().train_params_config.get_merged_multi_lookup(True): - raise RuntimeError("In training mode, `do_merge_lookup` should have been executed in compute gradients " - "phase. Please check whether compute gradients is performed.") - - logger.info("Graph has been revised.") - export_pb_graph("new_graph.pb", dump_graph) - - -def get_timestamp_index(get_next_op: Operation, is_training: bool) -> int: - timestamp_tensor_list = tf.compat.v1.get_collection(ASCEND_TIMESTAMP) +def _get_timestamp_index(graph: Graph, get_next_op: Operation, is_training: bool) -> int: + timestamp_tensor_list = graph.get_collection(ASCEND_TIMESTAMP) timestamp_index = None for timestamp in timestamp_tensor_list: if timestamp in get_next_op.outputs: @@ -606,43 +678,107 @@ def get_timestamp_index(get_next_op: Operation, is_training: bool) -> int: if timestamp_feature_spec is None: timestamp_feature_spec = FeatureSpec("timestamp", index_key=timestamp_index, is_timestamp=True) timestamp_feature_spec.include_timestamp(is_training) - ConfigInitializer.get_instance().feature_spec_config.insert_feature_spec(timestamp_feature_spec, - is_training) + ConfigInitializer.get_instance().feature_spec_config.insert_feature_spec( + timestamp_feature_spec, is_training + ) break if timestamp_feature_spec.index_key != timestamp_index: - raise ValueError(f"Given timestamp_index, which is {timestamp_index}, does not match index " - f"key. Please double check.") + raise ValueError( + f"Given timestamp_index, which is {timestamp_index}, does not match index " + f"key. Please double check." 
+ ) timestamp_feature_spec.include_timestamp(is_training) break return timestamp_index -class GraphModifierHook(tf.estimator.SessionRunHook): - @para_checker_decorator( - check_option_list=[ - ("dump_graph", ClassValidator, {"classes": (bool,)}), - ("modify_graph", ClassValidator, {"classes": (bool,)}) - ] - ) - def __init__(self, dump_graph=False, modify_graph=True): - self._dump_graph = dump_graph - self._modify_graph = modify_graph - self._iterator_type = "" - ConfigInitializer.get_instance().train_params_config.is_graph_modify_hook_running = True +def _change_ext_emb_size_by_opt(optimizer): + for _, table_instance in ConfigInitializer.get_instance().sparse_embed_config.table_instance_dict.items(): + # When dynamic expansion mode, ext_emb_size is set by optimizer + if ConfigInitializer.get_instance().use_dynamic_expansion or not table_instance.is_hbm: + table_instance.ext_emb_size = table_instance.emb_size * (1 + optimizer.slot_num) + logger.info("ext_emb_size is reset to be %s in change_ext_emb_size_by_opt", table_instance.ext_emb_size) + + +def _get_variable_and_slot_list(each_var, slot_num, table_name, channel_id): + variable_and_slot_list = [each_var] + if slot_num == 0: + return variable_and_slot_list + + # 通过apply_gradients创建optimizer + optimizer = ConfigInitializer.get_instance().optimizer_config.get_optimizer_by_table_name(table_name) + if optimizer is None and channel_id == 0: + raise RuntimeError( + "In training mode, table_instance should have been set_optimizer_for_table " + "before modify_graph, please check whether apply_gradients is performed" + ) + + # predict不需要传优化器,但是如果客户创建了优化器,ddr模式加载的是维度ext_size的emb用作换入换出,所以需要给slot零值占位 + if optimizer is None and channel_id == 1: + slot_place_holder = tf.zeros_like(each_var) + for _ in range(slot_num): + variable_and_slot_list.append(slot_place_holder) + else: + # opt name to slot dict + for slot_dict in optimizer.values(): + for slot_val in slot_dict.values(): + variable_and_slot_list.append(slot_val) - def begin(self): - if self._modify_graph: - modify_graph_and_start_emb_cache(dump_graph=self._dump_graph) - else: - start_asc_pipeline() + return variable_and_slot_list - self._iterator_type = ConfigInitializer.get_instance().train_params_config.iterator_type - if self._modify_graph and self._iterator_type not in (AnchorIteratorOp.MAKE_ITERATOR.value, - AnchorIteratorOp.ONE_SHOT_ITERATOR.value): - raise ValueError("The value of iterator type should be like `MakeIterator` or `OneShotIterator`.") - logger.debug("In GraphModifierHook, iterator type is `%s`.", self._iterator_type) - def after_create_session(self, session, coord): - if self._modify_graph and self._iterator_type == AnchorIteratorOp.MAKE_ITERATOR.value: - session.run(tf.compat.v1.get_collection(ASCEND_CUTTING_POINT_INITIALIZER)) +def _get_swap_info(table_instance: BaseSparseEmbedding, variable_and_slot_list: list, + swap_info: SwapInfo, channel_id: int) -> list: + """ + Get swap op. 
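+    Rows swap in from host memory via the '{table_name}_h2d_all' getnext channel and swap out via the
+    '{table_name}_d2h_all' outfeed channel.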
+ :param table_instance: BaseSparseEmbedding + :param variable_and_slot_list: [var + slots] + :param swap_info: swap in/out length and position + :param channel_id: train or predict + :return: swap op + """ + if table_instance.is_hbm: + return [tf.no_op()] + + if len(variable_and_slot_list) == 0: + raise RuntimeError("When enable emb_transfer, optimizer should have slots") + + use_static = ConfigInitializer.get_instance().use_static + max_lookup_vec_size = None + if use_static: + max_lookup_vec_size = table_instance.send_count * table_instance.rank_size + + with tf.compat.v1.variable_scope("h2d_emb"): + logger.debug('Channel %s_h2d_%s was built for getnext', table_instance.table_name, channel_id) + h2d_emb = npu_ops.gen_npu_ops.get_next( + output_types=[tf.float32], + output_shapes=[[max_lookup_vec_size, table_instance.ext_emb_size]], + channel_name=f'{table_instance.table_name}_h2d_all')[0] + logger.debug("h2d_emb shape: %s", h2d_emb) + + swap_out_pos = swap_info.swap_out_pos + swap_in_pos = swap_info.swap_in_pos + if use_static: + swap_out_pos = swap_out_pos[:swap_info.swap_out_len] + h2d_emb = h2d_emb[:swap_info.swap_in_len, :] + swap_in_pos = swap_in_pos[:swap_info.swap_in_len] + swap_outs = [tf.gather(one_table, swap_out_pos) for one_table in variable_and_slot_list] + swap_out = tf.concat(swap_outs, axis=1) + logger.debug('Channel %s_d2h_all was built for op outfeed.', table_instance.table_name) + + swap_out_op = npu_ops.outfeed_enqueue_op( + channel_name=f'{table_instance.table_name}_d2h_all', inputs=[swap_out]) + with tf.control_dependencies([swap_out_op]): + nd_swap_pos = tf.expand_dims(swap_in_pos, 1) + var_num = len(variable_and_slot_list) + h2d_emb_split = tf.split(h2d_emb, var_num, axis=1) + + optimizer = ConfigInitializer.get_instance().optimizer_config.get_optimizer_by_table_name( + table_instance.table_name) + if optimizer is None and channel_id == 1: + swap_in_op = [tf.compat.v1.scatter_nd_update(variable_and_slot_list[0], nd_swap_pos, h2d_emb_split[0])] + else: + swap_in_op = [tf.compat.v1.scatter_nd_update(variable_and_slot_list[i], nd_swap_pos, h2d_emb_split[i]) + for i in range(var_num)] + return swap_in_op diff --git a/mx_rec/graph/slicers.py b/mx_rec/graph/slicers.py new file mode 100644 index 0000000000000000000000000000000000000000..c86e60f1b69813d8a5021af6deb3c291e434d6cb --- /dev/null +++ b/mx_rec/graph/slicers.py @@ -0,0 +1,896 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import os +import abc +from typing import List, Dict, Set, Tuple, Union + +import pandas as pd +import tensorflow as tf +from tensorflow import Operation, Tensor, SparseTensor, Graph, variant, resource +from tensorflow.python.data.ops.dataset_ops import DatasetV1Adapter + +from mx_rec.graph import utils +from mx_rec.util.log import logger +from mx_rec.validator.validator import ClassValidator, para_checker_decorator +from mx_rec.constants.constants import ( + ASCAnchorAttr, + ASCEND_TIMESTAMP, + MAX_WHILE_SIZE, + ASCEND_SPARSE_LOOKUP_ENTRANCE, + ORPHAN_LOOKUP_KEY_PREFIX +) +from mx_rec.graph.constants import DeprecatedOp, AnchorDatasetOp, AnchorIteratorOp +from mx_rec.core.emb.base_sparse_embedding import BaseSparseEmbedding + + +class NoGradSubgraphSlicer(metaclass=abc.ABCMeta): + _SLICED_OP_NAME_PREFIX = "sliced" + + _SLICING_SUMMARY_NAME = "slicing_summary.csv" + _UNSLICED_FULL_GRAPH_NAME = "unsliced_full_graph.pbtxt" + _SLICED_SUB_GRAPH_NAME = "sliced_sub_graph.pbtxt" + _SLICED_FULL_GRAPH_NAME = "sliced_full_graph.pbtxt" + + _INVALID_STR_IN_OP_TYPE = ("Dataset", "Summary") + _INVALID_STR_IN_OP_NAME = ("save", "report_", "loss") + _INVALID_CONSUMER_OP_TYPE = ("Assign", "SaveV2") + + _VALID_TENSOR_CLASS = (Tensor, SparseTensor) + _INVALID_TENSOR_DTYPE = (variant, resource) + + def __init__(self, full_graph: Graph = None, info_dir: str = "slicing") -> None: + if not full_graph: + full_graph = tf.compat.v1.get_default_graph() + self._full_graph = full_graph + + if not os.path.exists(info_dir): + os.makedirs(info_dir) + self._info_dir = info_dir + + @staticmethod + def _find_min_dep_ops( + tgt_ops: Set[Operation], + ) -> Set[Operation]: + logger.debug("Search from base nodes: %s.", tgt_ops) + base_ops = tgt_ops.copy() + visited_ops = base_ops + + loop_cnt = 0 + while base_ops: + loop_cnt += 1 + if loop_cnt > MAX_WHILE_SIZE: + raise RuntimeError(f"maximum loop times exceed limit: {MAX_WHILE_SIZE}.") + + parent_ops = set() + for base_node in base_ops: + if len(base_node.control_inputs) != 0: + raise ValueError("control dependencies are not supported.") + + parent_ops.update( + tensor_in.op + for tensor_in in base_node.inputs + if tensor_in.op.type != AnchorIteratorOp.ITERATOR_GET_NEXT.value + ) + + new_ops = parent_ops - visited_ops + base_ops = parent_ops + visited_ops.update(new_ops) + + logger.debug("Found minimum dependency graph nodes: %s.", visited_ops) + return visited_ops + + @staticmethod + def _validate_op(op: Operation) -> bool: + op_type = op.type + op_name = op.name + op_inputs = op.inputs + op_outputs = op.outputs + + for s in NoGradSubgraphSlicer._INVALID_STR_IN_OP_TYPE: + if s in op_type: + logger.warning("Invalid operation type: %s which contains str: %s.", op_type, s) + return False + for s in NoGradSubgraphSlicer._INVALID_STR_IN_OP_NAME: + if s in op_name: + logger.warning("Invalid operation name: %s which contains str: %s.", op_name, s) + return False + for t in op_inputs: + if t.dtype in NoGradSubgraphSlicer._INVALID_TENSOR_DTYPE: + logger.warning("Invalid operation input tensor of operation: %s whose type is %s.", t, t.dtype) + return False + for t in op_outputs: + if t.dtype in NoGradSubgraphSlicer._INVALID_TENSOR_DTYPE: + logger.warning("Invalid operation output tensor of operation: %s whose type is %s.", t, t.dtype) + return False + + return True + + @staticmethod + def _update_subgraph_in( + base_ops: Operation, + input_to_edge_ops: Dict[Operation, Set[Operation]], + sub_graph_ops: 
Set[Operation], + ) -> None: + for input_tensor in base_ops.inputs: + input_node = input_tensor.op + if input_node not in sub_graph_ops: + res = input_to_edge_ops.get(input_node, set()) + res.add(base_ops) + input_to_edge_ops[input_node] = res + + @staticmethod + def _update_subgraph_out( + base_ops: Operation, + out_op_to_edge_ops: Dict[Operation, Set[Operation]], + sub_graph_ops: Set[Operation], + ) -> None: + for output_tensor in base_ops.outputs: + for output_consumer in output_tensor.consumers(): + if output_consumer not in sub_graph_ops: + res = out_op_to_edge_ops.get(output_consumer, set()) + res.add(base_ops) + out_op_to_edge_ops[output_consumer] = res + + + @staticmethod + def _topo_sort_sliced_ops(sliced_ops: Set[Operation]) -> List[Operation]: + topo_subgraph_list = [] + topo_subgraph_set = set() + start_nodes = set() + [start_nodes.add(x) for x in sliced_ops] + logger.info("Got topo_subgraph start nodes: %s", start_nodes) + + def topo_sort_helper(curr_op, output_list, output_set): + if not isinstance(curr_op, Operation): + raise RuntimeError(f"topo_subgraph_dfs input should be node(aka. tf.Operator). {curr_op}") + curr_inputs = curr_op.inputs + logger.debug("Got topo_dfs: %s <- %s", curr_op.name, [x.name for x in curr_inputs]) + current_control_inputs = curr_op.control_inputs + if len(current_control_inputs) > 0: + raise RuntimeError( + f"control input are not supported: {curr_op.name}, control_inputs: {current_control_inputs}" + ) + if curr_op in output_set: + return + output_set.add(curr_op) + for tensor in curr_inputs: + node = tensor.op + if node.type != AnchorIteratorOp.ITERATOR_GET_NEXT.value and node not in output_set: + topo_sort_helper(node, output_list, output_set) + output_list.append(curr_op) + + [topo_sort_helper(x, topo_subgraph_list, topo_subgraph_set) for x in start_nodes] + if len(topo_subgraph_list) != len(topo_subgraph_set): + raise RuntimeError(f"got duplicated topo node: {sorted(topo_subgraph_list, key=lambda x: x.name)}.") + logger.info("Got topo_subgraph: %s", topo_subgraph_list) + return topo_subgraph_list + + @staticmethod + def _get_mapping_for_subgraph_in( + from_op: Operation, + to_ops: Set[Operation], + tensor_mapping: Union[Dict[Tensor, Tensor], Dict[SparseTensor, SparseTensor]], + ) -> None: + if from_op.type != AnchorIteratorOp.ITERATOR_GET_NEXT.value: + raise RuntimeError(f"expect IteratorGetNext for input tensor of subgraph, but got {from_op}") + for node in to_ops: + for each_tensor in node.inputs: + if each_tensor.op.type != AnchorIteratorOp.ITERATOR_GET_NEXT.value: + continue + old_tensor_name = each_tensor.name + x_index = int(old_tensor_name.split(":")[-1]) + g = tf.compat.v1.get_default_graph() + arg_tensor = g.get_tensor_by_name("args_%d:0" % x_index) + tensor_mapping[each_tensor] = arg_tensor + + @staticmethod + def _get_mapping_for_subgraph( + old_op: Operation, + node_mapping: Dict[Operation, Operation], + tensor_mapping: Dict[Tensor, Tensor], + ) -> None: + logger.debug("old operation name: %s\nold operation inputs: %s\n", old_op.name, [x for x in old_op.inputs]) + + for each_tensor in old_op.inputs: + if each_tensor not in tensor_mapping: + raise RuntimeError( + f"each_tensor(input) {each_tensor} need by {old_op.name} not in tensor_mapping.{tensor_mapping}" + ) + new_inputs = NoGradSubgraphSlicer._get_mapped_tensor(tensor_mapping, old_op.inputs) + + node_def = old_op.node_def + node_def.name = "{}/{}".format(NoGradSubgraphSlicer._SLICED_OP_NAME_PREFIX, node_def.name) + new_node = tf.Operation(node_def=node_def, 
g=tf.compat.v1.get_default_graph(), inputs=new_inputs) + + node_mapping[old_op] = new_node + for old_out_tensor, new_out_tensor in zip(old_op.outputs, new_node.outputs): + tensor_mapping[old_out_tensor] = new_out_tensor + + @staticmethod + def _get_mapped_tensor(tensor2tensor: Dict[Tensor, Tensor], keys: List[Tensor]) -> List[Tensor]: + tensors = [] + for k in keys: + if k not in tensor2tensor: + raise KeyError(f"failed to find key tensor: {k} from tensor map: {tensor2tensor}.") + tensors.append(tensor2tensor[k]) + return tensors + + @staticmethod + def _sort_sliced_graph_outputs(subgraph_out: Dict[Operation, Set[Operation]]) -> List[Tensor]: + extra_outputs = [] + sorted_outputs = sorted(subgraph_out.items(), key=lambda x: x[0].name) + for outside_op, edge_ops in sorted_outputs: + outside_op_inputs = set(outside_op.inputs) + for edge_op in edge_ops: + NoGradSubgraphSlicer._add_sorted_additional_tensors(extra_outputs, outside_op_inputs, edge_op) + return extra_outputs + + @staticmethod + def _add_sorted_additional_tensors(extra_outputs, outside_op_inputs, edge_op) -> None: + for each_tensor in sorted(edge_op.outputs, key=lambda x: x.name): + if each_tensor not in outside_op_inputs: + continue + if each_tensor in extra_outputs: + continue + extra_outputs.append(each_tensor) + + @staticmethod + def _get_tensor_consumers(tensor: Tensor) -> List[Operation]: + if not isinstance(tensor, NoGradSubgraphSlicer._VALID_TENSOR_CLASS): + raise RuntimeError(f"expected 'tf.Tensor' or 'tf.SparseTensor', but got: {tensor}") + + graph = tensor.graph + consumers = [] + consumer_names = [op.name for op in tensor.consumers()] + + with graph._lock: + for name in consumer_names: + if name not in graph._nodes_by_name: # ignore deleted node + continue + consumers.append(graph._nodes_by_name[name]) + + return consumers + + @abc.abstractmethod + def summarize(self) -> None: + pass + + @abc.abstractmethod + def slice(self) -> None: + pass + + def _slice_ops(self, sliceable_ops: Set[Operation], is_training: bool) -> None: + """Slice the minimum dependency graph of given operation set. + + Args: + sliceable_ops (Set[Operation]): The operation set that can be sliced. + is_training (bool): Whether the slicing is for training graph or not. + """ + + sliced_ops = self._find_min_dep_ops(sliceable_ops) + in_op_to_edge_ops, out_op_to_edge_ops = self._find_subgraph_in_and_out(sliced_ops) + + old_get_next = self._find_old_get_next(sliceable_ops) + old_dataset = self._find_old_dataset(old_get_next, is_training) + + new_dataset = self._make_new_dataset(old_dataset, sliced_ops, in_op_to_edge_ops, out_op_to_edge_ops) + new_dataset = new_dataset.prefetch(0) + + new_get_next = self._make_new_get_next(old_get_next, new_dataset) + self._replace_get_next(old_get_next, new_get_next, out_op_to_edge_ops, sliced_ops) + + def _make_new_dataset( + self, + old_dataset: DatasetV1Adapter, + sliced_ops: Set[Operation], + in_op_to_edge_ops: Dict[Operation, Set[Operation]], + out_op_to_edge_ops: Dict[Operation, Set[Operation]], + ) -> DatasetV1Adapter: + """Make a new dataset which clones the sliced subgraph by mapfunc. + + Args: + old_dataset: The old dataset that needs to be mapped. + sliced_ops: The operation set that has been sliced. + in_op_to_edge_ops: The input relationship of sliced subgraph. + out_op_to_edge_ops: The output relationship of sliced subgraph. + + Returns: + DatasetV1Adapter: The new dataset that has cloned the sliced subgraph. 
+ """ + + def slice_map_func(*batch): # pragma: no cover + logger.debug("The layout of old batch: %s.", batch) + + funcgraph = tf.compat.v1.get_default_graph() + flatten_batch = tf.nest.flatten(batch) + + for t in flatten_batch: + if isinstance(t, NoGradSubgraphSlicer._VALID_TENSOR_CLASS): + continue + raise RuntimeError(f"expected 'tf.Tensor' or 'tf.SparseTensor' in batch, but got %s.", t) + + new_batch = self._clone_subgraph_into_funcgraph(sliced_ops, in_op_to_edge_ops, out_op_to_edge_ops, batch) + utils.export_pb_graph( + file_name=NoGradSubgraphSlicer._SLICED_SUB_GRAPH_NAME, + dump_graph=True, + graph_def=funcgraph.as_graph_def(), + export_path=self._info_dir, + ) + + return new_batch + + return old_dataset.map(slice_map_func) + + def _find_subgraph_in_and_out( + self, + sub_graph_ops: Set[Operation], + ) -> Tuple[Dict[Operation, Set[Operation]], Dict[Operation, Set[Operation]]]: + """Find the input and output relationship of sliced subgraph. + + Args: + sub_graph_ops: The operation set that has been sliced. + + Returns: + in_op_to_edge_ops: The input relationship of sliced subgraph. + out_op_to_edge_ops: The output relationship of sliced subgraph. + """ + + in_op_to_edge_ops = dict() + out_op_to_edge_ops = dict() + + for base_node in sub_graph_ops: + self._update_subgraph_in(base_node, in_op_to_edge_ops, sub_graph_ops) + self._update_subgraph_out(base_node, out_op_to_edge_ops, sub_graph_ops) + + logger.info("Got input relationship of extracted subgraph: %s", in_op_to_edge_ops) + logger.info("Got output relationship of extracted subgraph: %s", out_op_to_edge_ops) + return in_op_to_edge_ops, out_op_to_edge_ops + + def _find_old_get_next(self, sliceable_ops: Set[Operation]) -> Operation: + """Find the old 'IteratorGetNext' operation. + + Args: + sliceable_ops: The operation set that can be sliced. + + Returns: + old_get_next: The old 'IteratorGetNext' operation. + """ + + old_get_next = utils.upward_bfs_op(sliceable_ops, AnchorIteratorOp.ITERATOR_GET_NEXT.value) + + self._full_graph.add_to_collection(DeprecatedOp.DEPRECATED_ITERATOR_GET_NEXT, old_get_next) + logger.info("Old 'IteratorGetNext' operation has been deprecated now.") + + return old_get_next + + def _find_old_dataset(self, get_next: Operation, is_training: bool) -> DatasetV1Adapter: + """Find the old dataset that needs to be mapped. + + Due to the different iterator types, the search method is different. + 1. If the iterator type is 'MakeIterator', this func will exec upward bfs search through get_next. + 2. If the iterator type is 'OneShotIterator', this func will fetch all operation in 'self._full_graph', then + filter out the 'PrefetchDataset' operation. This diff is caused by the isolation of 'OneShotIterator' and the + 'PrefetchDataset'. + + Args: + get_next: The old 'IteratorGetNext' operation. + is_training: Whether the slicing is for training graph or not. + + Returns: + old_dataset: The old dataset that needs to be mapped. 
+ """ + + tgt_trans_dataset = None + try: + tgt_trans_dataset = utils.find_trans_dataset(self._full_graph, get_next) + except (ValueError, TypeError, RuntimeError) as err: + trans_datasets = [ + op for op in self._full_graph.get_operations() if AnchorDatasetOp.PREFETCH_DATASET.value in op.name + ] + trans_datasets = list( + filter( + lambda op: op not in tf.compat.v1.get_collection(DeprecatedOp.DEPRECATED_PREFETCH_DATASET), + trans_datasets, + ) + ) + sorted_datasets = sorted(trans_datasets, key=lambda op: op.name) + + if len(trans_datasets) == 1: + tgt_trans_dataset = sorted_datasets[0] + elif is_training and len(sorted_datasets) == 2: + tgt_trans_dataset = sorted_datasets[0] + elif not is_training and len(sorted_datasets) == 2: + tgt_trans_dataset = sorted_datasets[0] + else: + raise RuntimeError(f"target transformation dataset not found, got datasets: {trans_datasets}.") from err + except Exception as err: + raise RuntimeError(f"the dataset was not found, the error is `{err}`.") from err + + if not tgt_trans_dataset.outputs: + raise ValueError(f"the length of the outputs of target op `{tgt_trans_dataset}` is 0.") + logger.info("Find target op `%s`, and output is `%s`.", tgt_trans_dataset.name, tgt_trans_dataset.outputs) + + # WARN: Couple with modifier module, global collection used for filtering deprecated prefetch dataset. + self._full_graph.add_to_collection(DeprecatedOp.DEPRECATED_PREFETCH_DATASET, tgt_trans_dataset) + old_dataset = utils.find_target_instance_dataset(self._full_graph, tgt_trans_dataset.outputs[0]) + + return old_dataset + + def _clone_subgraph_into_funcgraph( + self, + sliced_ops: Set[Operation], + in_op_to_edge_ops: Set[Operation], + out_op_to_edge_ops: Set[Operation], + batch: Tuple[Dict[str, Union[Tensor, SparseTensor, Dict]]], + ) -> Dict[str, Union[Tensor, SparseTensor, Dict]]: + """Clone the sliced subgraph into a new funcgraph. + + Args: + sliced_ops: The operation set that has been sliced. + in_op_to_edge_ops: The input relationship of sliced subgraph. + out_op_to_edge_ops: The output relationship of sliced subgraph. + batch: The original batch layout of old dataset. + + Returns: + new_batch: The new batch layout of new dataset. + """ + + topo_subgraph_list = self._topo_sort_sliced_ops(sliced_ops) + + node_mapping = {} # subgraph-node -> funcgraph-node + tensor_mapping = {} # subgraph-tensor -> funcgraph-tensor + for in_op, edge_ops in in_op_to_edge_ops.items(): + self._get_mapping_for_subgraph_in(in_op, edge_ops, tensor_mapping) + for old_op in topo_subgraph_list: + self._get_mapping_for_subgraph(old_op, node_mapping, tensor_mapping) + + logger.info("Got node_mapping: %s", node_mapping) + logger.info("Got tensor_mapping: %s", tensor_mapping) + + ordered_output_tensors = self._sort_sliced_graph_outputs(out_op_to_edge_ops) + extra_output_tensor = self._get_mapped_tensor(tensor_mapping, ordered_output_tensors) + + if not isinstance(batch, tuple): + batch = (batch,) + + new_batch = batch[0] + for tensor in extra_output_tensor: + next_last_key = f"{sorted(new_batch)[-1]}_" + new_batch[next_last_key] = tensor + + logger.debug("Got new batch layout: %s.", new_batch) + return new_batch + + def _make_new_get_next( + self, + old_get_next: Operation, + new_dataset: DatasetV1Adapter, + ) -> Operation: + """Make new 'IteratorGetNext' operation. + + 1. This func will automatically detect the iterator type of the old dataset, and then make 'IteratorGetNext' + from the corresponding iterator. + 2. Only 'MakeIterator' and 'OneShotIterator' are available now. 
+
+        Args:
+            old_get_next: The old 'IteratorGetNext' operation.
+            new_dataset: The new dataset which contains sliced subgraph and corresponding additional outputs.
+
+        Returns:
+            new_get_next: The new 'IteratorGetNext' operation.
+        """
+
+        if not old_get_next.outputs:
+            raise RuntimeError("no available tensor in the dataset. Please check the dataset and data processing.")
+
+        iter_type = None
+        if old_get_next.inputs:
+            iter_type = old_get_next.inputs[0].op.type
+        if iter_type == AnchorIteratorOp.ITERATOR_V2.value:
+            iter_type = utils.find_make_iterator_op(self._full_graph, old_get_next.outputs[0]).type
+        if iter_type not in (AnchorIteratorOp.MAKE_ITERATOR.value, AnchorIteratorOp.ONE_SHOT_ITERATOR.value):
+            raise RuntimeError(
+                f"only iterators `MakeIterator` and `OneShotIterator` are supported in `graph modify` mode, "
+                f"but the current iterator is `{iter_type}`."
+            )
+        logger.info("The iterator type of old dataset is %s.", iter_type)
+
+        if iter_type == AnchorIteratorOp.MAKE_ITERATOR.value:
+            new_iterator = tf.compat.v1.data.make_initializable_iterator(new_dataset)
+        else:
+            new_iterator = tf.compat.v1.data.make_one_shot_iterator(new_dataset)
+        logger.info("Got new iterator: %s from dataset %s.", new_iterator, new_dataset)
+
+        new_batch_name = "{}/{}".format(
+            NoGradSubgraphSlicer._SLICED_OP_NAME_PREFIX, AnchorIteratorOp.ITERATOR_GET_NEXT.value
+        )
+        new_batch = new_iterator.get_next(name=new_batch_name)
+
+        # WARN: Couple with user model, this collection has been added manually.
+        if "timestamp" in new_batch.keys():
+            tf.compat.v1.add_to_collection(ASCEND_TIMESTAMP, new_batch["timestamp"])
+
+        try:
+            new_batch_tensor = new_batch
+            while not isinstance(new_batch_tensor, NoGradSubgraphSlicer._VALID_TENSOR_CLASS):
+                if isinstance(new_batch_tensor, tuple):
+                    new_batch_tensor = new_batch_tensor[0]
+                elif isinstance(new_batch_tensor, dict):
+                    new_batch_tensor = list(new_batch_tensor.values())
+                elif isinstance(new_batch_tensor, list):
+                    new_batch_tensor = new_batch_tensor[0]
+                else:
+                    raise RuntimeError(f"batch value {new_batch_tensor} of {type(new_batch_tensor)} is not supported.")
+        except IndexError as err:
+            raise IndexError("cannot find a tensor from given batch.") from err
+
+        new_get_next = utils.upward_bfs_op(new_batch_tensor.op, AnchorIteratorOp.ITERATOR_GET_NEXT.value)
+
+        logger.info("Got new_get_next: %s.", new_get_next)
+        return new_get_next
+
+    def _replace_get_next(
+        self,
+        old_get_next: Operation,
+        new_get_next: Operation,
+        out_op_to_edge_ops: Dict[Operation, Set[Operation]],
+        sliced_ops: Set[Operation],
+    ) -> None:
+        """Replace the old 'IteratorGetNext' operation with the new one.
+
+        1. This func will update the consumer of the old 'IteratorGetNext' operation to the new one.
+        2. This func will update the consumer of the output tensors of the sliced subgraph to the new one.
+
+        Args:
+            old_get_next: The old 'IteratorGetNext' operation.
+            new_get_next: The new 'IteratorGetNext' operation.
+            out_op_to_edge_ops: The output relationship of sliced subgraph.
+            sliced_ops: The operation set that has been sliced.
+ """ + + for t in old_get_next.outputs: + self._update_old_get_next_consumer(t, new_get_next, sliced_ops) + + next_offset = len(old_get_next.outputs) - 1 + sorted_outputs = self._sort_sliced_graph_outputs(out_op_to_edge_ops) + + for t in sorted_outputs: + next_offset += 1 + self._update_sliced_graph_consumer(t, new_get_next, next_offset) + + def _update_old_get_next_consumer( + self, old_get_next_output: Tensor, new_get_next: Operation, sliced_ops: Set[Operation] + ) -> None: + """Update the consumer of the old 'IteratorGetNext' operation to the new one. + + Args: + old_get_next_output: The output tensor of the old 'IteratorGetNext' operation. + new_get_next: The new 'IteratorGetNext' operation. + sliced_ops: The operation set that has been sliced. + """ + + old_tensor_name = old_get_next_output.name + output_index = old_tensor_name.split(":")[-1] + new_tensor_name = f"{new_get_next.name}:{output_index}" + new_tensor = self._full_graph.get_tensor_by_name(new_tensor_name) + + old_tensor_consumers = self._get_tensor_consumers(old_get_next_output) + for consumer in old_tensor_consumers: + if consumer in sliced_ops: + logger.debug("Ignore consumer: %s in sliced operations.", consumer.name) + continue + for i, t in enumerate(consumer.inputs): + if t != old_get_next_output: + logger.debug( + "Ignore input %s of consumer %s, cause it not output of 'IteratorGetNext'.", + t.name, + consumer.name, + ) + continue + consumer._update_input(i, new_tensor) + logger.debug( + "Succeed replace old input %s of consumer %s to new input %s.", + old_tensor_name, + consumer.name, + new_tensor, + ) + + def _update_sliced_graph_consumer( + self, sliced_graph_output: Tensor, new_get_next: Operation, next_offset: int + ) -> None: + """Update the consumer of the output tensors of the sliced subgraph to the new one. + + The outputs of the sliced subgraph are not the original outputs of 'IteratorGetNext'. Thus, next offset should + trace the last index of outputs of new 'IteratorGetNext'. + + Args: + sliced_graph_output: The output tensor of the sliced subgraph. + new_get_next: The new 'IteratorGetNext' operation. + next_offset: The last offset of the new 'IteratorGetNext' operation. + """ + + new_tensor_name = f"{new_get_next.name}:{next_offset}" + new_tensor = self._full_graph.get_tensor_by_name(new_tensor_name) + + old_tensor_consumers = self._get_tensor_consumers(sliced_graph_output) + for consumer in old_tensor_consumers: + if consumer.type in NoGradSubgraphSlicer._INVALID_CONSUMER_OP_TYPE: + logger.debug("Ignore invalid consumer: %s.", consumer.name) + continue + for i, t in enumerate(consumer.inputs): + if t != sliced_graph_output: + logger.debug( + "Ignore input %s of consumer %s, cause it not output of sliced graph.", + t.name, + consumer.name, + ) + continue + consumer._update_input(i, new_tensor) + logger.debug( + "Succeed replace old input %s of consumer %s to new input %s.", + sliced_graph_output, + consumer.name, + new_tensor, + ) + + +@para_checker_decorator( + check_option_list=[ + ("op_types", ClassValidator, {"classes": (list,)}), + ("full_graph", ClassValidator, {"classes": (Graph, type(None))}), + ("info_dir", ClassValidator, {"classes": (str,)}), + ] +) +class LookupSubgraphSlicer(NoGradSubgraphSlicer): + def __init__(self, op_types: List[str], full_graph: Graph = None, info_dir: str = "lookup_slicing") -> None: + """Initialize LookupSubgraphSlicer. + Args: + op_types: The list of operation types to be sliced in lookup subgraph. + full_graph: The full graph to be sliced. 
If None, the default graph will be used.
+            info_dir: The directory to save the slicing information. Defaults to "lookup_slicing".
+        """
+        super().__init__(full_graph, info_dir)
+        if not op_types:
+            raise ValueError("no slicing operation types specified!")
+        self._op_types = set(op_types)
+
+    def summarize(self) -> None:  # pragma: no cover
+        all_tgt_ops = self._find_all_tgt_ops()
+        (train_sliceable_tgt_ops, eval_sliceable_tgt_ops) = self._find_sliceable_tgt_ops()
+        all_sliceable_tgt_ops = train_sliceable_tgt_ops | eval_sliceable_tgt_ops
+
+        result = {"Operation Type": [], "Total Num": [], "Sliceable Num": [], "Sliceable Ratio": []}
+
+        for op_type in self._op_types:
+            tgt_ops = set(filter(lambda op: op.type == op_type, all_tgt_ops))
+            sliceable_tgt_ops = set(filter(lambda op: op.type == op_type, all_sliceable_tgt_ops))
+
+            total_num = len(tgt_ops)
+            sliceable_num = len(sliceable_tgt_ops)
+
+            try:
+                sliceable_ratio = sliceable_num / total_num
+            except ZeroDivisionError:
+                # Avoid an unbound 'sliceable_ratio' when no operation of this type exists in the graph.
+                sliceable_ratio = 0.0
+                logger.warning("No target operation types '%s' found in given graph.", self._op_types)
+
+            result["Operation Type"].append(op_type)
+            result["Total Num"].append(total_num)
+            result["Sliceable Num"].append(sliceable_num)
+            result["Sliceable Ratio"].append(sliceable_ratio)
+
+        result_df = pd.DataFrame(data=result)
+        file = "{}/{}".format(self._info_dir, NoGradSubgraphSlicer._SLICING_SUMMARY_NAME)
+        result_df.to_csv(file, sep=",")
+
+        logger.info("Summary of slicing:\n%s", result_df)
+
+    def slice(self) -> None:
+        utils.export_pb_graph(
+            file_name=NoGradSubgraphSlicer._UNSLICED_FULL_GRAPH_NAME,
+            dump_graph=True,
+            graph_def=self._full_graph.as_graph_def(),
+            export_path=self._info_dir,
+        )
+
+        (train_sliceable_ops, eval_sliceable_ops) = self._find_sliceable_tgt_ops()
+
+        if train_sliceable_ops:
+            logger.info("Start to slice training lookup subgraph.")
+            self._slice_ops(train_sliceable_ops, is_training=True)
+
+        if eval_sliceable_ops:
+            logger.info("Start to slice evaluation lookup subgraph.")
+            self._slice_ops(eval_sliceable_ops, is_training=False)
+
+        utils.export_pb_graph(
+            file_name=NoGradSubgraphSlicer._SLICED_FULL_GRAPH_NAME,
+            dump_graph=True,
+            graph_def=self._full_graph.as_graph_def(),
+            export_path=self._info_dir,
+        )
+
+    def _find_all_tgt_ops(self) -> Set[Operation]:
+        """Find all operations of the specified types in the full graph."""
+        all_tgt_ops = set()
+        all_ops = self._full_graph.get_operations()
+
+        for op in all_ops:
+            if op.type not in self._op_types:
+                continue
+            all_tgt_ops.add(op)
+
+        return all_tgt_ops
+
+    def _find_sliceable_tgt_ops(self) -> Tuple[Set[Operation], Set[Operation]]:
+        """Find sliceable operations of the given types in the lookup subgraph."""
+
+        # WARN: Couple with mx_rec::core::embedding module.
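+        # Each entry below is the Identity tensor planted at a sparse-lookup call site; its producing
+        # operation seeds the upward search for sliceable operations.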
+ lookup_keys = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE) + + train_base_ops = set() + eval_base_ops = set() + for t in lookup_keys: + if BaseSparseEmbedding.get_anchor_attribute(t, ASCAnchorAttr.IS_TRAINING): + train_base_ops.add(t.op) + else: + eval_base_ops.add(t.op) + + def find_sliceable_ops(base_ops): + min_dep_ops = self._find_min_dep_ops(base_ops) + + sliceable_ops = set() + for op in min_dep_ops: + if not self._validate_op(op): + continue + if op.type not in self._op_types: + continue + sliceable_ops.add(op) + + return sliceable_ops + + train_sliceable_ops = find_sliceable_ops(train_base_ops) + eval_sliceable_ops = find_sliceable_ops(eval_base_ops) + + logger.debug("Found sliceable operations in training lookup subgraph: %s.", train_sliceable_ops) + logger.debug("Found sliceable operations in evaluation lookup subgraph: %s.", eval_sliceable_ops) + return (train_sliceable_ops, eval_sliceable_ops) + + +@para_checker_decorator( + check_option_list=[ + ("full_graph", ClassValidator, {"classes": (Graph, type(None))}), + ("info_dir", ClassValidator, {"classes": (str,)}), + ] +) +class OrphanLookupKeySlicer(NoGradSubgraphSlicer): + def __init__(self, full_graph: Graph = None, info_dir: str = "orphan_slicing") -> None: + """Initialize OrphanLookupKeySlicer. + Args: + full_graph: The full graph to be sliced. If None, the default graph will be used. + info_dir: The directory to save the slicing information. Defaults to "orphan_slicing". + """ + super().__init__(full_graph, info_dir) + + def summarize(self) -> None: # pragma: no cover + (train_sliceable_ops, _) = self._find_sliceable_tgt_ops() + + if len(train_sliceable_ops) == 0: + return + + result = {"Operation Type": [], "Operation Name": []} + for op in train_sliceable_ops: + result["Operation Type"].append(op.type) + result["Operation Name"].append(op.name) + + result_df = pd.DataFrame(data=result) + file = "{}/{}".format(self._info_dir, NoGradSubgraphSlicer._SLICING_SUMMARY_NAME) + result_df.to_csv(file, sep=",") + + logger.info("Summary of slicing:\n%s", result_df) + + def slice(self) -> None: + utils.export_pb_graph( + file_name=NoGradSubgraphSlicer._UNSLICED_FULL_GRAPH_NAME, + dump_graph=True, + graph_def=self._full_graph.as_graph_def(), + export_path=self._info_dir, + ) + + (train_sliceable_ops, eval_sliceable_ops) = self._find_sliceable_tgt_ops() + + if train_sliceable_ops: + logger.info("Start to slice training lookup subgraph.") + self._slice_ops(train_sliceable_ops, is_training=True) + + if eval_sliceable_ops: + logger.info("Start to slice evaluation lookup subgraph.") + self._slice_ops(eval_sliceable_ops, is_training=False) + + utils.export_pb_graph( + file_name=NoGradSubgraphSlicer._SLICED_FULL_GRAPH_NAME, + dump_graph=True, + graph_def=self._full_graph.as_graph_def(), + export_path=self._info_dir, + ) + + def _slice_ops(self, sliceable_ops: Set[Operation], is_training: bool) -> None: + """Override the '_slice_ops' protected method of super class.""" + + sliced_ops = self._find_min_dep_ops(sliceable_ops) + in_op_to_edge_ops, out_op_to_edge_ops = self._find_subgraph_in_and_out(sliced_ops) + + all_get_nexts = [ + op for op in self._full_graph.get_operations() if op.type == AnchorIteratorOp.ITERATOR_GET_NEXT.value + ] + alive_get_nexts = list( + filter( + lambda op: op not in self._full_graph.get_collection(DeprecatedOp.DEPRECATED_ITERATOR_GET_NEXT), + all_get_nexts, + ) + ) + alive_get_nexts = sorted(alive_get_nexts, key=lambda op: op.name) + + old_get_next = None + if len(alive_get_nexts) == 1: + old_get_next 
= alive_get_nexts[0]
+        else:
+            old_get_next = alive_get_nexts[0] if is_training else alive_get_nexts[1]
+
+        old_dataset = self._find_old_dataset(old_get_next, is_training)
+
+        new_dataset = self._make_new_dataset(old_dataset, sliced_ops, in_op_to_edge_ops, out_op_to_edge_ops)
+        new_dataset = new_dataset.prefetch(0)
+
+        new_get_next = self._make_new_get_next(old_get_next, new_dataset)
+        self._replace_get_next(old_get_next, new_get_next, out_op_to_edge_ops, sliced_ops)
+
+    def _find_sliceable_tgt_ops(self) -> Tuple[Set[Operation], Set[Operation]]:
+        """Find orphan keys' additional identity operations in the lookup subgraph."""
+
+        # WARN: Couple with mx_rec::core::embedding module.
+        lookup_keys = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE)
+
+        train_base_ops = set()
+        eval_base_ops = set()
+        for t in lookup_keys:
+            if BaseSparseEmbedding.get_anchor_attribute(t, ASCAnchorAttr.IS_TRAINING):
+                train_base_ops.add(t.op)
+            else:
+                eval_base_ops.add(t.op)
+
+        def find_sliceable_ops(base_ops):
+            min_dep_ops = self._find_min_dep_ops(base_ops)
+
+            sliceable_ops = set()
+            for op in min_dep_ops:
+                if not self._validate_op(op):
+                    continue
+                if ORPHAN_LOOKUP_KEY_PREFIX not in op.name:
+                    continue
+                sliceable_ops.add(op)
+
+            return sliceable_ops
+
+        train_sliceable_ops = find_sliceable_ops(train_base_ops)
+        eval_sliceable_ops = find_sliceable_ops(eval_base_ops)
+
+        logger.debug("Found sliceable operations in training lookup subgraph: %s.", train_sliceable_ops)
+        logger.debug("Found sliceable operations in evaluation lookup subgraph: %s.", eval_sliceable_ops)
+        return (train_sliceable_ops, eval_sliceable_ops)
diff --git a/mx_rec/graph/utils.py b/mx_rec/graph/utils.py
index c010d80d8bdfcb0f17abdf361c8314f6d96a782d..17f071ac6bb737148105c2304ebaef3ccc05b14e 100644
--- a/mx_rec/graph/utils.py
+++ b/mx_rec/graph/utils.py
@@ -17,53 +17,136 @@
 import os
 from collections import defaultdict
-from typing import List, Dict, Union
+from typing import List, Dict, Set, Union, DefaultDict, Tuple
 
 import tensorflow as tf
-from tensorflow import Operation, Tensor
+from tensorflow import Operation, Tensor, Graph
 from tensorflow.core.framework.graph_pb2 import GraphDef
+from tensorflow.python.data.ops.dataset_ops import DatasetV1Adapter
 from tensorflow.python.framework.errors_impl import InvalidArgumentError
+from tensorflow.python.ops import control_flow_ops
 
+from mx_rec.graph.constants import AnchorDatasetOp, AnchorIteratorOp
 from mx_rec.constants.constants import ASCAnchorAttr, DUMP_MIDIFY_GRAPH_FILE_MODE
 from mx_rec.core.embedding import BaseSparseEmbedding
-from mx_rec.graph.graph_typing import ReplacementSpec
 from mx_rec.util.log import logger
 
 
-def check_input_list(objs: Union[object, List[object]], obj_type: type) -> Union[object, List[object]]:
-    if isinstance(objs, obj_type):
-        objs = [objs]
+def find_trans_dataset(graph: Graph, get_next: Operation) -> Operation:
+    """Find the transformation dataset through 'get_next'.
 
-    if isinstance(objs, list):
-        for tensor in objs:
-            if not isinstance(tensor, obj_type):
-                raise ValueError(f"Given input parameter must be a {obj_type} or a list of {obj_type}")
+    Args:
+        graph: The graph in which to search.
+        get_next: The old 'IteratorGetNext' operation.
+
+    Returns:
+        trans_dataset: The target transformation dataset.
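+
+    Raises:
+        TypeError: If 'get_next' is not an 'IteratorGetNext' operation, or 'OptimizeDataset' is missing (TF1).
+        RuntimeError: If the parent operation of 'ModelDataset' cannot be found (TF1).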
+ """ + + if get_next.type != AnchorIteratorOp.ITERATOR_GET_NEXT.value: + raise TypeError(f"operation '{get_next}' must be one instance of 'IteratorGetNext'.") + + make_iter = find_make_iterator_op(graph, get_next.outputs[0]) - return objs + trans_dataset = None + if tf.__version__.startswith("1"): + optimize_dataset_op = upward_bfs_op(make_iter, AnchorDatasetOp.MODEL_DATASET.value) + trans_dataset = find_parent_op(optimize_dataset_op) + if not trans_dataset: + raise RuntimeError("parent operation of 'ModelDataset' was not found.") + if trans_dataset[0].type != AnchorDatasetOp.OPTIMIZE_DATASET.value: + raise TypeError(f"operation 'OptimizeDataset' was not found.") + trans_dataset = trans_dataset[0] + else: + trans_dataset = upward_bfs_op(make_iter, AnchorDatasetOp.PREFETCH_DATASET.value) + + return trans_dataset + + +def find_make_iterator_op(graph: Graph, batch_tensor: Tensor) -> Operation: + operations = graph.get_operations() + for each_op in operations: + for input_tensor in batch_tensor.op.inputs: + if ( + input_tensor.op.outputs + and input_tensor.op.outputs[0] in list(each_op.inputs) + and each_op.type == AnchorIteratorOp.MAKE_ITERATOR.value + ): + logger.debug("Op MakeIterator '%s' was found.", each_op.name) + return each_op + + raise ValueError(f"operation `MakeIterator` cannot be found.") def find_parent_op(operator: Operation) -> List[Operation]: parent_ops = [] for input_tensor in operator.inputs: parent_op = input_tensor.op - if isinstance(parent_op, tf.Operation): + if isinstance(parent_op, Operation): parent_ops.append(parent_op) return parent_ops +def upward_bfs_op(base_ops: Union[Operation, Set[Operation], List[Operation]], tgt_op_type: str) -> Operation: + if not isinstance(base_ops, (set, list)): + base_ops = [base_ops] + + parent_ops = base_ops + while True: + for parent_op in parent_ops: + if parent_op.type == tgt_op_type: + return parent_op + base_ops = parent_ops + parent_ops = [] + for base_op in base_ops: + parent_ops.extend(find_parent_op(base_op)) + if not parent_ops: + raise ValueError(f"target operation '{tgt_op_type}'' was not found.") + + +def find_target_instance_dataset(graph: Graph, variant_tensor: Tensor) -> DatasetV1Adapter: + dataset_instance_list = graph.get_collection("dataset_group") + for ins in dataset_instance_list: + if ins._variant_tensor == variant_tensor: + if not isinstance(ins, DatasetV1Adapter): + ins = ins._input_dataset + logger.debug("Find target instance '%s', whose variant_tensor is '%s'.", ins, variant_tensor) + if not isinstance(ins.element_spec, dict) and not ( + isinstance(ins.element_spec, (list, tuple)) + and len(ins.element_spec) == 2 + and isinstance(ins.element_spec[0], dict) + ): + raise NotImplementedError("the found dataset does not return a valid layout.") + + return ins + + raise LookupError(f"Can not find target instance, whose variant_tensor is '{variant_tensor}' respectively.") + + +def check_and_force_list(obj: Union[object, List[object]], obj_type: type) -> Union[object, List[object]]: + if isinstance(obj, obj_type): + obj = [obj] + + if isinstance(obj, list): + for tensor in obj: + if not isinstance(tensor, obj_type): + raise ValueError(f"Given input parameter must be a {obj_type} or a list of {obj_type}") + + return obj + + def check_cutting_points(cutting_point_list: List[Tensor]): for tensor in cutting_point_list: - if not isinstance(tensor, tf.Tensor): + if not isinstance(tensor, Tensor): raise TypeError(f"Collection ASCEND_CUTTING_POINT can only contain Tensors, but '{tensor}' was found.") if tensor.op.type != 
"Identity": raise ValueError(f"Cutting point can only be the output of an Operator 'Identity'.") -def record_ops_to_replace(src_op: Operation) -> ReplacementSpec: +def record_ops_to_replace(graph: Graph, src_op: Operation) -> DefaultDict[Tensor, List[Tuple[int, Operation]]]: replacement_specs = defaultdict(list) output_list = src_op.outputs - op_list = tf.compat.v1.get_default_graph().get_operations() + op_list = graph.get_operations() for tensor in output_list: for operator in op_list: if tensor in operator.inputs: @@ -73,74 +156,83 @@ def record_ops_to_replace(src_op: Operation) -> ReplacementSpec: return replacement_specs -def replace_anchor(replacement_specs: ReplacementSpec, new_tensor_list: List[Tensor]): +def replace_anchor(replacement_specs: DefaultDict[Tensor, List[Tuple[int, Operation]]], new_tensor_list: List[Tensor]): if len(replacement_specs) != len(new_tensor_list): - raise ValueError(f"Given replacement_specs and new_tensor_list must have the same length. " - f"replacement_specs: {replacement_specs}, new_tensor_list: {new_tensor_list}") + raise ValueError( + f"Given replacement_specs and new_tensor_list must have the same length. " + f"replacement_specs: {replacement_specs}, new_tensor_list: {new_tensor_list}" + ) for tensor_idx, (old_tensor, items) in enumerate(replacement_specs.items()): for input_idx, operator in items: try: operator._update_input(input_idx, new_tensor_list[tensor_idx]) except InvalidArgumentError as err: - logger.info("The replacement specs keys (old batch) is: %s. \n\t\t The new_tensor_list is: %s.", - replacement_specs.keys(), new_tensor_list) - raise RuntimeError(f"Cannot update edge, old tensor: {old_tensor}, " - f"new tensor: {new_tensor_list[tensor_idx]}.") from err + logger.info( + "The replacement specs keys (old batch) is: %s. \n\t\t The new_tensor_list is: %s.", + replacement_specs.keys(), + new_tensor_list, + ) + raise RuntimeError( + f"Cannot update edge, old tensor: {old_tensor}, " f"new tensor: {new_tensor_list[tensor_idx]}." + ) from err -def export_pb_graph(file_name: str, - dump_graph: bool = False, - graph_def: GraphDef = None, - export_path: str = "./export_graph", - as_text: bool = False): +def replace_anchor_control(graph: Graph, place_holder_control: tf.Operation, real_anchor: Tensor): """ - Save tensorflow graph before and after modifier graph - :param file_name: FileName of the graph - :param dump_graph: Is serialize graph or not - :param graph_def: A Graph or a GraphDef protocol buffer. - :param export_path: Directory where to write the graph. - This can refer to remote filesystems, such as Google Cloud Storage (GCS). - :param as_text: If True, writes the graph as an ASCII proto - :return: None + 将place_holder_control替换为入参real_anchor. + + Args: + place_holder_control: control op + real_anchor: 用来替换打桩节点的tensor + + Returns: None + """ - if dump_graph: - dir_path = os.path.dirname(os.path.join(export_path, file_name)) - os.makedirs(dir_path, mode=DUMP_MIDIFY_GRAPH_FILE_MODE, exist_ok=True) - graph_def = graph_def if graph_def else tf.compat.v1.get_default_graph().as_graph_def() - tf.io.write_graph(graph_def, export_path, file_name, as_text) + if place_holder_control is None: + raise RuntimeError( + f"Node place_holder_control does not exist. Check whether the sparse lookup interface " + f"is correctly invoked." 
+ ) + # find the op with stub node as the input + replacement_specs_for_anchor_vec = record_control_to_replace(graph, place_holder_control) + # replace anchor_vec with anchor + replace_control_anchor(replacement_specs_for_anchor_vec, real_anchor) -def make_sorted_key_to_tensor_list( - element_spec: List[Dict[str, Tensor]], - sorted_keys: List[str], - prefix: str = "" -) -> List[str]: - if isinstance(element_spec, tf.TensorSpec): - sorted_keys.append(prefix) - return sorted_keys - elif isinstance(element_spec, dict): - for key, item in element_spec.items(): - if not isinstance(key, str): - raise TypeError(f"The key of element_spec must be a string.") +def record_control_to_replace(graph: Graph, src_op: Operation) -> DefaultDict[Tensor, List[Tuple[int, Operation]]]: + replacement_specs = defaultdict(list) + op_list = graph.get_operations() + for operator in op_list: + if src_op in operator.control_inputs: + input_index = operator.control_inputs.index(src_op) + replacement_specs[src_op].append((input_index, operator)) - prefix = "{0}_{1}".format(prefix, key) - sorted_keys = make_sorted_key_to_tensor_list(item, sorted_keys, prefix=prefix) - sorted_keys = sorted(sorted_keys) - return sorted_keys + return replacement_specs - elif isinstance(element_spec, (list, tuple)): - for idx, item in enumerate(element_spec): - prefix = "{0}_{1}".format(prefix, str(idx)) - sorted_keys = make_sorted_key_to_tensor_list(item, sorted_keys, prefix=prefix) - sorted_keys = sorted(sorted_keys) - return sorted_keys - raise TypeError(f"Given element_spec, whose type is {type(element_spec)}, is invalid.") +def replace_control_anchor( + replacement_specs: DefaultDict[Tensor, List[Tuple[int, Operation]]], new_tensor_list: List[Tensor] +): + + for tensor_idx, (old_tensor, items) in enumerate(replacement_specs.items()): + for _, operator in items: + try: + control_op = control_flow_ops.group(new_tensor_list) + operator._add_control_input(control_op) + except InvalidArgumentError as err: + logger.info( + "The replacement control specs keys (old batch) is: %s. \n\t\t The new_tensor_list is: %s.", + replacement_specs.keys(), + new_tensor_list, + ) + raise RuntimeError( + f"Cannot update edge, old tensor: {old_tensor}, " f"new tensor: {new_tensor_list[tensor_idx]}." + ) from err -def replace_anchor_vec(cutting_point: Tensor, attribute: ASCAnchorAttr, anchor: Tensor): +def replace_anchor_vec(graph: Graph, cutting_point: Tensor, attribute: ASCAnchorAttr, anchor: Tensor): """ 根据打桩节点的名字找到以此为输入的op,并将该op的输入替换为入参anchor. @@ -156,23 +248,61 @@ def replace_anchor_vec(cutting_point: Tensor, attribute: ASCAnchorAttr, anchor: # get stub node anchor_vec = BaseSparseEmbedding.get_anchor_attribute(cutting_point, attribute) if anchor_vec is None: - raise RuntimeError(f"Node `{attribute.value}` does not exist. Check whether the sparse lookup interface " - f"is correctly invoked.") + raise RuntimeError( + f"Node `{attribute.value}` does not exist. Check whether the sparse lookup interface " + f"is correctly invoked." 
+    )
     # find the op with stub node as the input
-    replacement_specs_for_anchor_vec = record_ops_to_replace(anchor_vec.op)
+    replacement_specs_for_anchor_vec = record_ops_to_replace(graph, anchor_vec.op)
     # replace anchor_vec with anchor
     replace_anchor(replacement_specs_for_anchor_vec, [anchor])


-def tag_orphan_ids(ids: tf.Tensor) -> tf.Tensor:
+def make_sorted_key_to_tensor_list(
+    element_spec: List[Dict[str, Tensor]], sorted_keys: List[str], prefix: str = ""
+) -> List[str]:
+    if isinstance(element_spec, tf.TensorSpec):
+        sorted_keys.append(prefix)
+        return sorted_keys
+    elif isinstance(element_spec, dict):
+        for key, item in element_spec.items():
+            if not isinstance(key, str):
+                raise TypeError(f"The key of element_spec must be a string.")
+
+            prefix = "{0}_{1}".format(prefix, key)
+            sorted_keys = make_sorted_key_to_tensor_list(item, sorted_keys, prefix=prefix)
+            sorted_keys = sorted(sorted_keys)
+        return sorted_keys
+
+    elif isinstance(element_spec, (list, tuple)):
+        for idx, item in enumerate(element_spec):
+            prefix = "{0}_{1}".format(prefix, str(idx))
+            sorted_keys = make_sorted_key_to_tensor_list(item, sorted_keys, prefix=prefix)
+            sorted_keys = sorted(sorted_keys)
+        return sorted_keys
+
+    raise TypeError(f"Given element_spec, whose type is {type(element_spec)}, is invalid.")
+
+
+def export_pb_graph(
+    file_name: str,
+    dump_graph: bool = False,
+    graph_def: GraphDef = None,
+    export_path: str = "./export_graph",
+    as_text: bool = True,
+):
     """
-    将孤儿ids使用identity操作创建ACG_PUSH_NODE前缀命名的标记节点,以便在PushOps时能找到。
+    Save the TensorFlow graph before and after graph modification
+    :param file_name: FileName of the graph
+    :param dump_graph: whether to serialize the graph or not
+    :param graph_def: A Graph or a GraphDef protocol buffer.
+    :param export_path: Directory where to write the graph.
+        This can refer to remote filesystems, such as Google Cloud Storage (GCS).
+ :param as_text: If True, writes the graph as an ASCII proto + :return: None """ - graph_def = tf.compat.v1.get_default_graph().as_graph_def() - subgraph = tf.compat.v1.graph_util.extract_sub_graph(graph_def, [ids.op.name]) - for node in subgraph.node: - if node.op == 'IteratorGetNext': - return ids - new_ids = tf.identity(ids, name=f"ACG_PUSH_NODE_{ids.op.name}") - logger.info('Tag orphan op node: %s with %s.', ids, new_ids) - return new_ids + if dump_graph: + dir_path = os.path.dirname(os.path.join(export_path, file_name)) + os.makedirs(dir_path, mode=DUMP_MIDIFY_GRAPH_FILE_MODE, exist_ok=True) + graph_def = graph_def if graph_def else tf.compat.v1.get_default_graph().as_graph_def() + tf.io.write_graph(graph_def, export_path, file_name, as_text) diff --git a/mx_rec/optimizers/adagrad.py b/mx_rec/optimizers/adagrad.py index d99be3b32b8ccac9a63742c896deafcc6733f64f..df1fe2a3072682e8328ebde299eb6ec0688fad60 100644 --- a/mx_rec/optimizers/adagrad.py +++ b/mx_rec/optimizers/adagrad.py @@ -21,12 +21,12 @@ from __future__ import print_function from collections import defaultdict +from tensorflow.python.framework import ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.training import adagrad, training_ops -from tensorflow.python.training import slot_creator -from mx_rec.optimizers.base import CustomizedOptimizer +from mx_rec.optimizers.base import CustomizedOptimizer, control_update_op_decorator from mx_rec.util.initialize import ConfigInitializer from mx_rec.validator.validator import para_checker_decorator, StringValidator, ClassValidator, FloatValidator @@ -76,30 +76,8 @@ class CustomizedAdagrad(adagrad.AdagradOptimizer, CustomizedOptimizer): initial_accumulator_value=initial_accumulator_value, use_locking=use_locking, name=self.unique_name) - - def initialize_slots(self, var, table_instance): - # Create slots for the first and second moments. - def creat_one_single_slot(var, op_name): - new_slot_variable = slot_creator.create_zeros_slot(var, op_name) - # make sure sparse optimizer statements will not be saved and restored within tf checkpoint. - return new_slot_variable - - accumulator = creat_one_single_slot(var, self._name + "/" + "accumulator") - ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(accumulator.name) - named_slot_key = (var.op.graph, var.op.name) - table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance(var) - ConfigInitializer.get_instance().optimizer_config.set_optimizer_for_table(table_instance.table_name, - self.optimizer_type, - {"accumulator": accumulator}) - return [{"slot": accumulator, "named_slot_key": named_slot_key, "slot_name": "acc", "optimizer": self}] - - def insert_slot(self, slot, named_slots_key, slot_name): - named_slots = self._slot_dict(slot_name) - if named_slots_key in named_slots: - raise EnvironmentError(f"named_slots_key should be global unique, but it has been in use now, " - f"please double check.") - - named_slots[named_slots_key] = slot + self._slot_num = 1 + self._derivative = 2 def get_slot_init_values(self): # return state value list of adagrad that needs to initialize in ASC DDR. 
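Editor's note: the `control_update_op_decorator` imported in the hunk above is defined in `mx_rec/optimizers/base.py` later in this patch and is applied to the sparse-apply methods in the next hunk. For readers unfamiliar with the pattern, the sketch below shows the core idea under simplified assumptions; the helper name `with_placeholder_control` is illustrative, and the real decorator additionally registers the stub op with `SwapArgs` so a later graph pass can splice in the real swap-in control edges.

```python
import tensorflow as tf

def with_placeholder_control(apply_fn):
    # Illustrative sketch only (TF1-style graph mode): create a named no-op
    # stub and run the wrapped sparse-apply function under a control
    # dependency on it. A later graph-rewrite pass can locate the stub by
    # name and replace it with the real control inputs.
    def wrapper(*args, **kwargs):
        stub = tf.no_op(name="place_holder_slot_control_op")
        with tf.control_dependencies([stub]):
            return apply_fn(*args, **kwargs)
    return wrapper
```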
@@ -119,6 +97,21 @@ class CustomizedAdagrad(adagrad.AdagradOptimizer, CustomizedOptimizer):
         self._get_or_make_slot_with_initializer(var, init, var.get_shape(), dtype,
                                                 "acc", acc_state_name)

+    def _apply_sparse_duplicate_indices(self, grad, var):
+        # The _apply_sparse_duplicate_indices method includes tf.unique and unsorted_segment_sum operations, which
+        # may introduce dynamic-shape problems; if you encounter that, uncomment the method below.
+        unique_local_grad, unique_keys = self.sum_same_id_gradients(grad=grad.values, var=var, is_expansion=False)
+        gradient_no_duplicate_indices = ops.IndexedSlices(
+            indices=unique_keys,
+            values=unique_local_grad,
+            dense_shape=grad.dense_shape)
+        return self._apply_sparse(gradient_no_duplicate_indices, var)
+
+    def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
+        unique_local_grad, unique_keys = self.sum_same_id_gradients(grad=grad, var=handle, is_expansion=False)
+        return self._resource_apply_sparse(unique_local_grad, handle, unique_keys)
+
+    @control_update_op_decorator
     def _apply_sparse(self, grad, var):
         acc = self.get_slot(var, "acc")
         return training_ops.sparse_apply_adagrad(
@@ -127,6 +120,7 @@ class CustomizedAdagrad(adagrad.AdagradOptimizer, CustomizedOptimizer):
             grad.indices,
             use_locking=self._use_locking)

+    @control_update_op_decorator
     def _resource_apply_sparse(self, grad, var, indices):
         acc = self.get_slot(var, "acc")
         return training_ops.resource_sparse_apply_adagrad(
diff --git a/mx_rec/optimizers/adagrad_by_addr.py b/mx_rec/optimizers/adagrad_by_addr.py
new file mode 100644
index 0000000000000000000000000000000000000000..72f1d86e93c982df8c34f77c97269a6129e27e03
--- /dev/null
+++ b/mx_rec/optimizers/adagrad_by_addr.py
@@ -0,0 +1,125 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import, division, print_function
+
+from typing import List
+
+import tensorflow as tf
+from tensorflow.python.ops import math_ops
+from tensorflow.python.training import adagrad
+from tensorflow.python.training.optimizer import Optimizer
+
+from mx_rec.optimizers.base import CustomizedOptimizer
+from mx_rec.util.initialize import ConfigInitializer
+from mx_rec.util.ops import import_host_pipeline_ops
+from mx_rec.validator.validator import (
+    FloatValidator,
+    StringValidator,
+    para_checker_decorator,
+)
+
+
+@para_checker_decorator(
+    check_option_list=[
+        ("learning_rate", FloatValidator, {"min_value": 0.0, "max_value": 10.0}, ["check_value"]),
+        (
+            "initial_accumulator_value",
+            FloatValidator,
+            {"min_value": 0.0, "max_value": 1.0},
+            ["check_value_for_left_open_interval"],
+        ),
+        ("name", StringValidator, {"min_len": 1, "max_len": 200}, ["check_string_length"]),
+    ]
+)
+def create_hash_optimizer_by_address(learning_rate=0.001, initial_accumulator_value=0.9, name="Adagrad") -> Optimizer:
+    """Create an instance of adagrad hash optimizer.
+
+    Args:
+        learning_rate: A `Tensor` or a floating point value. The learning rate.
+        initial_accumulator_value: A floating point value. Starting value for the accumulators, must be positive.
+        name: Optional name prefix for the operations created when applying gradients. Defaults to "Adagrad".
+
+    Returns:
+        Adagrad hash optimizer instance
+
+    Raises:
+        ValueError: If `use_dynamic_expansion` was not set.
+    """
+    if not ConfigInitializer.get_instance().use_dynamic_expansion:
+        raise ValueError(
+            "the optimizer requires dynamic expansion mode, please config dynamic "
+            "expansion mode and optimizer correctly"
+        )
+    optimizer = CustomizedAdagradByAddress(
+        learning_rate=learning_rate,
+        initial_accumulator_value=initial_accumulator_value,
+        name=name,
+    )
+    ConfigInitializer.get_instance().optimizer_config.optimizer_instance = optimizer
+    return optimizer
+
+
+class CustomizedAdagradByAddress(adagrad.AdagradOptimizer, CustomizedOptimizer):
+    def __init__(
+        self,
+        learning_rate: float,
+        initial_accumulator_value: float,
+        name="Adagrad",
+    ):
+        self.optimizer_type = "Adagrad"
+        self.optim_param_list = ["accumulator"]
+        super(CustomizedAdagradByAddress, self)._get_name(name=name)
+        super(CustomizedAdagradByAddress, self).__init__(
+            learning_rate=learning_rate,
+            initial_accumulator_value=initial_accumulator_value,
+            name=self.unique_name,
+        )
+        self._epsilon = 1e-7
+        self._slot_num = 1
+        self._derivative = 2
+
+    def get_slot_init_values(self) -> List[float]:
+        # return state value list of adagrad that needs to initialize in ASC DDR.
+        return [self._initial_accumulator_value]
+
+    def _apply_sparse(self, grad: tf.Tensor, var: tf.Tensor) -> tf.Operation:
+        grad, var = self.sum_same_id_gradients(grad=grad, var=var, is_expansion=True)
+        learning_rate_tensor = math_ops.cast(self._learning_rate_tensor, grad.dtype.base_dtype)
+        epsilon = math_ops.cast(self._epsilon, grad.dtype.base_dtype)
+
+        host_pipeline_ops = import_host_pipeline_ops()
+        dim = grad.shape.as_list()[-1]
+
+        combined_tensor = host_pipeline_ops.embedding_lookup_by_address(var, embedding_dim=2 * dim, embedding_type=1)
+        split_length = [dim] + [dim]
+        split_tensors = tf.split(combined_tensor, split_length, axis=1)
+
+        old_s_slice = split_tensors[1]
+        s_t_slice = old_s_slice + math_ops.square(grad)
+
+        denominator_slice = math_ops.sqrt(s_t_slice + epsilon)
+
+        update_list = [tf.divide(-learning_rate_tensor * grad, denominator_slice)] + [s_t_slice - old_s_slice]
+        update_tensor = tf.concat(update_list, axis=1)
+        var_update_op = host_pipeline_ops.embedding_update_by_address(var, update_tensor, update_type=0)
+
+        return var_update_op
+
+    def _create_slots(self, var_list: List[tf.Variable]):
+        # Slot variables are managed by the lookup operator; skip the parent class implementation.
+        pass
diff --git a/mx_rec/optimizers/base.py b/mx_rec/optimizers/base.py
index a5d68a704d9e130da48d0d8bb78182ff37f3be52..496296413e3d9d5cd3ce75bae8ad0f03af9e24de 100644
--- a/mx_rec/optimizers/base.py
+++ b/mx_rec/optimizers/base.py
@@ -21,10 +21,60 @@ from __future__ import print_function

 from collections import defaultdict

+import tensorflow as tf
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
 from tensorflow.python.training.optimizer import _TensorProcessor

+from mx_rec.core.asc.swap_args import SwapArgs
+from mx_rec.constants.constants import ASCAnchorAttr
+from mx_rec.util.tf_version_adapter import npu_ops
+from mx_rec.util.initialize import ConfigInitializer
 from mx_rec.util.log import logger
+from mx_rec.util.communication.hccl_ops import get_rank_size
+
+
+def get_restore_vector_second(table_name: str, max_lookup_vec_size: int) -> tf.Tensor:
+    """
+    Get the restore vector calculated after the second all2all
+    :param table_name: embedding table_name
+    :param max_lookup_vec_size: static shape
+    :return: the restore vector calculated after the second all2all
+    """
+    channel_id = 0
+    logger.debug('Channel %s_restore_second_%s was built for getnext',
+                 table_name, channel_id)
+    with tf.compat.v1.variable_scope(table_name, reuse=tf.compat.v1.AUTO_REUSE):
+        restore_vector_second = npu_ops.gen_npu_ops.get_next(
+            output_types=[tf.int32],
+            output_shapes=[[max_lookup_vec_size]],
+            channel_name=f'{table_name}_restore_second_{channel_id}')[0]
+    return restore_vector_second
+
+
+def get_unique_keys(table_name: str, max_lookup_vec_size: int, is_expansion: bool) -> tf.Tensor:
+    """
+    Get the global unique keys calculated after the second all2all
+    :param table_name: embedding table_name
+    :param max_lookup_vec_size: static shape
+    :param is_expansion: use dynamic expansion
+    :return: the global unique keys calculated after the second all2all
+    """
+    channel_id = 0
+    logger.debug('Channel %s_uniquekeys_%s was built for getnext', table_name, channel_id)
+    with tf.compat.v1.variable_scope(table_name, reuse=tf.compat.v1.AUTO_REUSE):
+        if is_expansion:
+            unique_keys = npu_ops.gen_npu_ops.get_next(
+                output_types=[tf.int64],
+                output_shapes=[[max_lookup_vec_size]],
+                channel_name=f'{table_name}_uniquekeys_{channel_id}')[0]
+            return unique_keys
+
+        unique_keys = npu_ops.gen_npu_ops.get_next(
+            output_types=[tf.int32],
+            output_shapes=[[max_lookup_vec_size]],
+            channel_name=f'{table_name}_uniquekeys_{channel_id}')[0]
+        return unique_keys


 class CustomizedOptimizer:
@@ -34,12 +84,43 @@ class CustomizedOptimizer:
     def __init__(self):
         self.unique_name = ""
         self.base_name = ""
+        self._slot_num = 0  # number of slots used by the optimizer
+        self._derivative = 1  # optimizer order: 1 if skipping global deduplication is mathematically equivalent, otherwise 2
+
+    @property
+    def slot_num(self):
+        return self._slot_num
+
+    @property
+    def derivative(self):
+        return self._derivative
+
+    @staticmethod
+    def sum_same_id_gradients(grad, var, is_expansion):
+        if isinstance(var, ops.Tensor):
+            table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance_by_tensor(var)
+            table_name = table_instance.table_name
+        else:
+            table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance(var)
+            table_name = table_instance.table_name

-    def initialize_slots(self, var, table_instance):
-        raise NotImplementedError(f"Please define a specific realization on {self.__class__.__name__}")
+        max_lookup_vec_size = None
+        use_static = ConfigInitializer.get_instance().use_static
+        if use_static:
+            send_count = table_instance.send_count
+            rank_size = get_rank_size()
+            max_lookup_vec_size = send_count * rank_size if send_count > 0 else None

-    def insert_slot(self, slot, named_slots_key, slot_name):
-        raise NotImplementedError(f"Please define a specific realization on {self.__class__.__name__}")
+        with tf.compat.v1.variable_scope(str(ASCAnchorAttr.RESTORE_VECTOR_SECOND)):
+            restore_vector_second = get_restore_vector_second(table_name, max_lookup_vec_size)
+
+        with tf.compat.v1.variable_scope(str(ASCAnchorAttr.UNIQUE_KEYS)):
+            unique_keys = get_unique_keys(table_name, max_lookup_vec_size, is_expansion)
+
+        unique_local_grad = tf.compat.v1.unsorted_segment_sum(grad,
+                                                              restore_vector_second,
+                                                              array_ops.shape(unique_keys)[0])
+        return unique_local_grad, unique_keys

     def get_slot_init_values(self):
         raise NotImplementedError(f"Please define a specific realization on {self.__class__.__name__}")
@@ -63,6 +144,18 @@ def custom_update_op(self, opt, grad):
         raise RuntimeError("Only support g with type Tensor.")


+def control_update_op_decorator(apply_sparse):
+    def wrapper(*args, **kwargs):
+        second_arg = args[2] if len(args) > 2 else None  # the input at index 2 must be var
+        slot_control_ops = tf.no_op(name="place_holder_slot_control_op")
+        swap_args = SwapArgs()
+        swap_args.set_slot_control(var_name=second_arg, control_ops=slot_control_ops)
+        with tf.control_dependencies([slot_control_ops]):
+            result = apply_sparse(*args, **kwargs)
+        return result
+    return wrapper
+
+
 def patch_for_optimizer():
     _TensorProcessor.update_op = custom_update_op
     logger.debug("update_op in Class optimizer._TensorProcessor has been patched.")
\ No newline at end of file
diff --git a/mx_rec/optimizers/emb_optimizer.py b/mx_rec/optimizers/emb_optimizer.py
deleted file mode 100644
index c7f1b64ab00700609a3796f28cd38c94009f3431..0000000000000000000000000000000000000000
--- a/mx_rec/optimizers/emb_optimizer.py
+++ /dev/null
@@ -1,76 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
-
-from mx_rec.optimizers.base import CustomizedOptimizer
-from mx_rec.util.tf_version_adapter import NPULossScaleOptimizer
-
-
-class EmbOptimizer:
-    """
-    稀疏表的优化器.
- """ - - def __init__(self, optimizer_list): - self._optimizer_instance_list = optimizer_list - self._optimizer_slot_info_list = [] - self._optimizer = dict() - - @property - def optimizer_instance_list(self): - return self._optimizer_instance_list - - @property - def optimizer_slot_info_list(self): - return self._optimizer_slot_info_list - - @property - def optimizer(self): - return self._optimizer - - @staticmethod - def set_optimizer_slot(slot_info: dict): - """ - 设置稀疏表优化器的slot信息. - - Args: - slot_info: 优化器slot信息 - - Returns: None - """ - slot = slot_info.get("slot") - slot_name = slot_info.get("slot_name") - optimizer = slot_info.get("optimizer") - named_slot_key = slot_info.get("named_slot_key") - - optimizer.insert_slot(slot, named_slot_key, slot_name) - - def set_optimizer(self, key: str, state_dict: dict, table_name: str): - """ - 设置optimizer state. - - Args: - key: 优化器名字 - state_dict: optimizer state - table_name: 稀疏表名 - - Returns: None - """ - if key in self._optimizer: - raise ValueError(f"Optimizer {key} has been set for hash table {table_name}.") - self._optimizer[key] = state_dict - - def check_optimizer_instance_list(self): - """ - 校验优化器实例列表. - """ - if not self._optimizer_instance_list: - raise ValueError("External storage mode should config optimizers before instantiating sparse table, " - "but nothing was configured.") - - for optimizer_instance in self._optimizer_instance_list: - if isinstance(optimizer_instance, NPULossScaleOptimizer): - optimizer_instance = getattr(optimizer_instance, '_opt') - - if not isinstance(optimizer_instance, CustomizedOptimizer): - raise TypeError("The optimizer instance must be an instance of CustomizedOptimizer.") diff --git a/mx_rec/optimizers/ftrl.py b/mx_rec/optimizers/ftrl.py index 5c68b9292a6b11f6ca7d5012097930a0424bffba..ad4f98801748493de101101777a73b75f15cebf2 100644 --- a/mx_rec/optimizers/ftrl.py +++ b/mx_rec/optimizers/ftrl.py @@ -29,9 +29,8 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import gen_state_ops from tensorflow.python.training import ftrl -from tensorflow.python.training import slot_creator -from mx_rec.optimizers.base import CustomizedOptimizer +from mx_rec.optimizers.base import CustomizedOptimizer, control_update_op_decorator from mx_rec.util.initialize import ConfigInitializer from mx_rec.constants.constants import MAX_INT32 from mx_rec.validator.validator import para_checker_decorator, ClassValidator, StringValidator, \ @@ -80,34 +79,7 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer): l2_shrinkage_regularization_strength=kwargs.get("l2_shrinkage_regularization_strength", 0.0) ) self._slot_num = 2 - - @property - def slot_num(self): - return self._slot_num - - def initialize_slots(self, var, table_instance): - val = constant_op.constant( - self._initial_accumulator_value, dtype=var.dtype, shape=var.get_shape()) - - accum = slot_creator.create_slot(var, val, self._name + "/" + "accum") - linear = slot_creator.create_zeros_slot(var, self._name + "/" + "linear") - ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(accum.name) - ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(linear.name) - named_slot_key = (var.op.graph, var.op.name) - table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance(var) - ConfigInitializer.get_instance().optimizer_config.set_optimizer_for_table(table_instance.table_name, - self.optimizer_type, - {"accum": accum, 
"linear": linear}) - return [{"slot": accum, "named_slot_key": named_slot_key, "slot_name": "accum", "optimizer": self}, - {"slot": linear, "named_slot_key": named_slot_key, "slot_name": "linear", "optimizer": self}] - - def insert_slot(self, slot, named_slots_key, slot_name): - named_slots = self._slot_dict(slot_name) - if named_slots_key in named_slots: - raise EnvironmentError(f"named_slots_key should be global unique, but it has been in use now, " - f"please double check.") - - named_slots[named_slots_key] = slot + self._derivative = 2 def get_slot_init_values(self): # return state value list of ftrl that needs to initialize in ASC DDR. @@ -115,10 +87,18 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer): return [self._initial_accumulator_value, initial_linear_value] def _apply_sparse_duplicate_indices(self, grad, var): - return self._apply_sparse(grad, var) + # _apply_sparse_duplicate_indices method include tf.unique and unsorted_segment_sum operations which may + # introduce dynamic shape problem, if encounter that, please de-annotation the method below. + unique_local_grad, unique_keys = self.sum_same_id_gradients(grad=grad.values, var=var, is_expansion=False) + gradient_no_duplicate_indices = ops.IndexedSlices( + indices=unique_keys, + values=unique_local_grad, + dense_shape=grad.dense_shape) + return self._apply_sparse(gradient_no_duplicate_indices, var) def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices): - return self._resource_apply_sparse(grad, handle, indices) + unique_local_grad, unique_keys = self.sum_same_id_gradients(grad=grad, var=handle, is_expansion=False) + return self._resource_apply_sparse(unique_local_grad, handle, unique_keys) def _resource_apply_sparse(self, grad, handle, indices): if self._l2_shrinkage_regularization_strength <= 0.0: @@ -148,6 +128,7 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer): grad.indices, lambda x, i, v: tf.compat.v1.scatter_nd_update(x, i, v)) + @control_update_op_decorator def _apply_sparse_shared(self, grad, var, indices, scatter_nd_update): accum = self.get_slot(var, "accum") linear = self.get_slot(var, "linear") @@ -189,6 +170,7 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer): return control_flow_ops.group(accum_update_op, linear_update_op, var_update_op) + @control_update_op_decorator def _apply_sparse_shared_v2(self, grad, var, indices, scatter_nd_update): accum = self.get_slot(var, "accum") linear = self.get_slot(var, "linear") diff --git a/mx_rec/optimizers/gradient_descent.py b/mx_rec/optimizers/gradient_descent.py index 6881d6ad7e6c3ae2da5d61b18932366d20c2dc30..89d67d8965e2787372a62f0b7c6a9d47425df99d 100644 --- a/mx_rec/optimizers/gradient_descent.py +++ b/mx_rec/optimizers/gradient_descent.py @@ -55,13 +55,7 @@ class CustomizedGradientDescent(gradient_descent.GradientDescentOptimizer, Custo super(CustomizedGradientDescent, self).__init__(learning_rate=learning_rate, use_locking=use_locking, name=self.unique_name) self._slot_num = 0 - - @property - def slot_num(self): - return self._slot_num - - def initialize_slots(self, var, table_instance): - return [] + self._derivative = 1 def get_slot_init_values(self): return [] diff --git a/mx_rec/optimizers/gradient_descent_by_addr.py b/mx_rec/optimizers/gradient_descent_by_addr.py index 22b33852ce1669f7e57e9fc6f85547c5550e27a9..8cf9257e5c5ad4d115f5cc4e3166997b2a165c47 100644 --- a/mx_rec/optimizers/gradient_descent_by_addr.py +++ b/mx_rec/optimizers/gradient_descent_by_addr.py @@ -60,13 +60,7 @@ class 
CustomizedGradientDescentByAddr(gradient_descent.GradientDescentOptimizer, name=self.unique_name) self._slot_num = 0 - - @property - def slot_num(self): - return self._slot_num - - def initialize_slots(self, var, table_instance): - return [] + self._derivative = 1 def get_slot_init_values(self): return [] diff --git a/mx_rec/optimizers/lazy_adam.py b/mx_rec/optimizers/lazy_adam.py index d79b6d2393f1e921830a2d4384378d79f6ff5105..ac88afc9d8b41d1ad944c173b1f0ebd100ef6228 100644 --- a/mx_rec/optimizers/lazy_adam.py +++ b/mx_rec/optimizers/lazy_adam.py @@ -28,11 +28,11 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_state_ops from tensorflow.python.ops import math_ops from tensorflow.python.training import adam -from tensorflow.python.training import slot_creator -from mx_rec.optimizers.base import CustomizedOptimizer +from mx_rec.optimizers.base import CustomizedOptimizer, control_update_op_decorator from mx_rec.util.initialize import ConfigInitializer -from mx_rec.validator.validator import para_checker_decorator, StringValidator, FloatValidator +from mx_rec.util.ops import import_host_pipeline_ops +from mx_rec.validator.validator import para_checker_decorator, StringValidator, FloatValidator, ClassValidator @para_checker_decorator(check_option_list=[ @@ -40,9 +40,11 @@ from mx_rec.validator.validator import para_checker_decorator, StringValidator, ("beta1", FloatValidator, {"min_value": 0.0, "max_value": 1.0}, ["check_value_for_open_interval"]), ("beta2", FloatValidator, {"min_value": 0.0, "max_value": 1.0}, ["check_value"]), ("epsilon", FloatValidator, {"min_value": 0.0, "max_value": 1.0}, ["check_value_for_left_open_interval"]), - ("name", StringValidator, {"min_len": 1, "max_len": 200}, ["check_string_length"]) + ("name", StringValidator, {"min_len": 1, "max_len": 200}, ["check_string_length"]), + ("use_fusion_optim", ClassValidator, {"classes": (bool,)}), ]) -def create_hash_optimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, name="LazyAdam"): +def create_hash_optimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, name="LazyAdam", + use_fusion_optim=False): """ Args: learning_rate: learning rate @@ -50,13 +52,14 @@ def create_hash_optimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1 beta2: epsilon: name: - + use_fusion_optim: if use fused optimizer Returns: a customized optimizer instance """ if ConfigInitializer.get_instance().use_dynamic_expansion: raise ValueError("dynamic expansion mode is not compatible with the optimizer, please config dynamic " "expansion mode and optimizer correctly") - optimizer = CustomizedLazyAdam(learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon, name=name) + optimizer = CustomizedLazyAdam(learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon, name=name, + use_fusion_optim=use_fusion_optim) ConfigInitializer.get_instance().optimizer_config.optimizer_instance = optimizer return optimizer @@ -64,46 +67,21 @@ def create_hash_optimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1 class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): name_counter = defaultdict(int) - def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, use_locking=False, name="LazyAdam"): + def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, use_locking=False, name="LazyAdam", + use_fusion_optim=False): self.optimizer_type = "LazyAdam" self.optim_param_list = ["momentum", "velocity"] 
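+        # Note: when use_fusion_optim is enabled, the initial beta1/beta2/epsilon values are
+        # cached below so that _apply_sparse_shared can hand them to the fused host-pipeline
+        # lazy_adam kernel instead of running the elementwise update path.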
self.config_instance = ConfigInitializer.get_instance() + self.use_fusion_optim = use_fusion_optim + if self.use_fusion_optim: + self._custom_initial_beta1 = beta1 + self._custom_initial_beta2 = beta2 + self._custom_initial_epsilon = epsilon super(CustomizedLazyAdam, self)._get_name(name=name) super(CustomizedLazyAdam, self).__init__(learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon, use_locking=use_locking, name=self.unique_name) self._slot_num = 2 - - @property - def slot_num(self): - return self._slot_num - - def initialize_slots(self, var, table_instance): - # Create slots for the first and second moments. - def creat_one_single_slot(var, op_name): - new_slot_variable = slot_creator.create_zeros_slot(var, op_name) - # make sure sparse optimizer statements will not be saved and restored within tf checkpoint. - return new_slot_variable - - momentum = creat_one_single_slot(var, self._name + "/" + "momentum") - velocity = creat_one_single_slot(var, self._name + "/" + "velocity") - self.config_instance.sparse_embed_config.insert_removing_var_list(momentum.name) - self.config_instance.sparse_embed_config.insert_removing_var_list(velocity.name) - named_slot_key = (var.op.graph, var.op.name) - table_instance = self.config_instance.sparse_embed_config.get_table_instance(var) - ConfigInitializer.get_instance().optimizer_config.set_optimizer_for_table(table_instance.table_name, - self.optimizer_type, - {"momentum": momentum, - "velocity": velocity}) - return [{"slot": momentum, "named_slot_key": named_slot_key, "slot_name": "m", "optimizer": self}, - {"slot": velocity, "named_slot_key": named_slot_key, "slot_name": "v", "optimizer": self}] - - def insert_slot(self, slot, named_slots_key, slot_name): - named_slots = self._slot_dict(slot_name) - if named_slots_key in named_slots: - raise EnvironmentError(f"named_slots_key should be global unique, but it has been in use now, " - f"please double check.") - - named_slots[named_slots_key] = slot + self._derivative = 2 def get_slot_init_values(self): # return state value list of adam that needs to initialize in ASC DDR. @@ -114,10 +92,16 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): def _apply_sparse_duplicate_indices(self, grad, var): # _apply_sparse_duplicate_indices method include tf.unique and unsorted_segment_sum operations which may # introduce dynamic shape problem, if encounter that, please de-annotation the method below. 
- return self._apply_sparse(grad, var) + unique_local_grad, unique_keys = self.sum_same_id_gradients(grad=grad.values, var=var, is_expansion=False) + gradient_no_duplicate_indices = ops.IndexedSlices( + indices=unique_keys, + values=unique_local_grad, + dense_shape=grad.dense_shape) + return self._apply_sparse(gradient_no_duplicate_indices, var) def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices): - return self._resource_apply_sparse(grad, handle, indices) + unique_local_grad, unique_keys = self.sum_same_id_gradients(grad=grad, var=handle, is_expansion=False) + return self._resource_apply_sparse(unique_local_grad, handle, unique_keys) def _apply_dense(self, grad, var): raise NotImplementedError("You are using a wrong type of variable.") @@ -136,6 +120,7 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): } return temp + @control_update_op_decorator def _resource_apply_sparse(self, grad, handle, indices): return self._apply_sparse_shared( grad, @@ -143,6 +128,7 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): indices, self._resource_scatter_nd_add) + @control_update_op_decorator def _apply_sparse(self, grad, var): return self._apply_sparse_shared( grad.values, @@ -161,6 +147,16 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): temp_epsilon = temp.get("temp_epsilon") learning_rate = tf.divide(temp_lr * math_ops.sqrt(1 - power_b2), (1 - power_b1)) + if self.use_fusion_optim: + nd_indices = tf.expand_dims(indices, 1) + slot_m = self.get_slot(var, "m") + slot_v = self.get_slot(var, "v") + output_m, output_v, output_var = \ + import_host_pipeline_ops().lazy_adam(grad, nd_indices, slot_m, slot_v, var, learning_rate, + self._custom_initial_beta1, self._custom_initial_beta2, + self._custom_initial_epsilon) + return control_flow_ops.group(output_m, output_v, output_var) + abs_indices = tf.math.maximum(indices, 0) nd_indices = tf.expand_dims(indices, 1) @@ -174,7 +170,7 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): v_t_slice = temp_b2 * old_v_slice + (1 - temp_b2) * math_ops.square(grad) v_update_op = scatter_nd_add(velocity, nd_indices, v_t_slice - old_v_slice) - denominator_slice = math_ops.sqrt(v_t_slice) + temp_epsilon + denominator_slice = math_ops.sqrt(tf.abs(v_t_slice)) + temp_epsilon var_update_op = scatter_nd_add(var, nd_indices, tf.divide(-learning_rate * m_t_slice, denominator_slice)) return control_flow_ops.group(m_update_op, v_update_op, var_update_op) diff --git a/mx_rec/optimizers/lazy_adam_by_addr.py b/mx_rec/optimizers/lazy_adam_by_addr.py index 9225282439a155605fdb215754360b02b340b5ce..1d5aacd2ceb33d5d115e235d1c0990c18f32e701 100644 --- a/mx_rec/optimizers/lazy_adam_by_addr.py +++ b/mx_rec/optimizers/lazy_adam_by_addr.py @@ -73,10 +73,7 @@ class CustomizedLazyAdamByAddress(adam.AdamOptimizer, CustomizedOptimizer): name=self.unique_name) self._slot_num = 2 - - @property - def slot_num(self): - return self._slot_num + self._derivative = 2 def get_slot_init_values(self): # return state value list of adam that needs to initialize in ASC DDR. 
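Editor's note: the `sum_same_id_gradients` helper these overrides rely on (added to `mx_rec/optimizers/base.py` above) merges gradient rows that hit the same embedding key before the update is applied. In the shipped code the restore vector and unique keys arrive over fixed-shape NPU getnext channels; the standalone sketch below reproduces only the mathematical reduction with plain TensorFlow ops, and the function and variable names here are illustrative.

```python
import tensorflow as tf

def sum_gradients_per_key(grad_values, ids):
    # Rows of grad_values whose ids coincide are summed into a single row per
    # unique id, so each embedding key receives exactly one combined update.
    unique_ids, restore_vector = tf.unique(ids)
    unique_grad = tf.math.unsorted_segment_sum(
        grad_values, restore_vector, tf.shape(unique_ids)[0])
    return unique_grad, unique_ids
```

Doing this inline with `tf.unique` reintroduces the dynamic shapes the surrounding comments warn about, which is why the patched optimizers consume precomputed vectors with a static `max_lookup_vec_size` when `use_static` is set.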
@@ -109,9 +106,10 @@ class CustomizedLazyAdamByAddress(adam.AdamOptimizer, CustomizedOptimizer): return temp def _apply_sparse(self, grad, addr): + unique_local_grad, unique_addr = self.sum_same_id_gradients(grad=grad, var=addr, is_expansion=True) return self._apply_sparse_shared( - grad, - addr) + unique_local_grad, + unique_addr) def _apply_sparse_shared(self, grad, addr): power_b1, power_b2 = self._get_beta_accumulators() @@ -138,7 +136,7 @@ class CustomizedLazyAdamByAddress(adam.AdamOptimizer, CustomizedOptimizer): old_v_slice = split_tensors[2] v_t_slice = temp_b2 * old_v_slice + (1 - temp_b2) * math_ops.square(grad) - denominator_slice = math_ops.sqrt(v_t_slice) + temp_epsilon + denominator_slice = math_ops.sqrt(tf.abs(v_t_slice)) + temp_epsilon update_list = [tf.divide(-learning_rate * m_t_slice, denominator_slice)] + [m_t_slice - old_m_slice] + \ [v_t_slice - old_v_slice] update_tensor = tf.concat(update_list, axis=1) diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index fcf1134f565c8e89d4e9844d28143acc2ef7b06a..f57e8ce0edfe7496b51ea95f927f87ac34840f4d 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -23,6 +23,7 @@ import os import time import tensorflow as tf +from tensorflow.compat.v1.summary import FileWriter from tensorflow.core.protobuf import saver_pb2 from tensorflow.core.protobuf import trackable_object_graph_pb2 from tensorflow.python import pywrap_tensorflow @@ -30,6 +31,7 @@ from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import graph_io from tensorflow.python.ops import variables from tensorflow.python.ops import io_ops from tensorflow.python.platform import gfile @@ -41,14 +43,18 @@ from tensorflow.python.training.tracking import base as trackable from tensorflow.python.training.saving import saveable_object from tensorflow.python.training.saving import saveable_object_util import numpy as np +from mpi4py import MPI -from mx_rec.saver.saver import Saver as SparseSaver, check_file_system_is_valid +from mx_rec.saver.saver import Saver as SparseSaver, check_file_system_is_valid, should_write_data +from mx_rec.util.communication.hccl_ops import get_rank_id from mx_rec.util.initialize import ConfigInitializer from mx_rec.validator.validator import para_checker_decorator, ClassValidator, StringValidator, OptionalIntValidator, \ OptionalStringValidator, DirectoryValidator from mx_rec.util.log import logger from mx_rec.constants.constants import MAX_INT32, INVALID_CHARS +_FILENAME_SUFFIX = "filename_suffix" + def get_sparse_vars(var_list): sparse_var_list = [] @@ -248,11 +254,10 @@ def save(self, sess, save_path, global_step=None, latest_filename=None, meta_gra self.sparse_saver.save(sess, save_path=checkpoint_file) logger.info("Save sparse model into dir %s", checkpoint_file) - from mpi4py import MPI comm = MPI.COMM_WORLD rank = comm.Get_rank() comm.Barrier() - if rank == 0: + if should_write_data(rank, save_path): model_checkpoint_path = compat.as_str(get_model_checkpoint_path(self, checkpoint_file, sess)) if write_state: update_checkpoint_state(self, model_checkpoint_path, save_path_parent, latest_filename, meta_graph_suffix, @@ -289,7 +294,7 @@ def restore(self, sess, save_path): if self._is_empty: return if not checkpoint_management.checkpoint_exists_internal(checkpoint_prefix): - raise ValueError("The passed save_path is not a valid checkpoint: " + + raise ValueError("the passed 
save_path is not a valid checkpoint: " +
                         checkpoint_prefix)
     tf_logging.info("Restoring parameters from %s", checkpoint_prefix)
@@ -447,6 +452,19 @@ class BulkSaverBuilder(BaseSaverBuilder):
         return io_ops.restore_v2(filename_tensor, tensor_names, tensor_slices, tensor_dtypes)


+def patch_for_write_graph_func(func):
+    def wrapper(*args, **kwargs):
+        comm = MPI.COMM_WORLD
+        rank = comm.Get_rank()
+        # In the case of multiple processes, choose one process to write the graph.
+        if len(args) > 1 and should_write_data(rank, args[1]):
+            return func(*args, **kwargs)
+        else:
+            return None
+
+    return wrapper
+
+
 def patch_for_saver():
     dense_saver = tf.compat.v1.train.Saver
     dense_saver.__init__ = saver_init
@@ -454,3 +472,25 @@ def patch_for_saver():
     dense_saver.restore = restore
     dense_saver.build = build
     logger.debug("Class tf.train.Saver has been patched.")
+    training_util.write_graph = patch_for_write_graph_func(graph_io.write_graph)
+
+
+def _patch_for_summary_writer(func):
+    def wrapper(*args, **kwargs):
+        filename_suffix = kwargs.get(_FILENAME_SUFFIX, "")
+        filename_suffix = filename_suffix or ""
+        rank_suffix = "_rank" + str(get_rank_id())
+        if rank_suffix not in filename_suffix:
+            filename_suffix = rank_suffix + "_" + filename_suffix if filename_suffix else rank_suffix
+        kwargs[_FILENAME_SUFFIX] = filename_suffix
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+def patch_for_summary_writer():
+    """
+    Patch for the `tf.summary.FileWriter.__init__` method; adds the rank id to the init param `filename_suffix`.
+    """
+    FileWriter.__init__ = _patch_for_summary_writer(FileWriter.__init__)
+    logger.debug("Method `tf.summary.FileWriter.__init__` has been patched.")
diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py
index d776b699379e379047c3b14cd0de3281c21d1a9a..a6362506adca393c597c3984653aee9eecdaaa0b 100644
--- a/mx_rec/saver/saver.py
+++ b/mx_rec/saver/saver.py
@@ -24,7 +24,7 @@ import tensorflow as tf
 from tensorflow.python.util import compat

 from mx_rec.constants.constants import DataName, DataAttr, MIN_SIZE, MAX_FILE_SIZE, Flag, TFDevice, \
-    MAX_INT32, HDFS_FILE_PREFIX
+    MAX_INT32, HDFS_FILE_PREFIX, TRAIN_CHANNEL_ID
 from mx_rec.util.communication.hccl_ops import get_rank_id, get_rank_size, get_local_rank_size
 from mx_rec.util.initialize import ConfigInitializer
 from mx_rec.util.perf import performance
@@ -33,6 +33,9 @@ from mx_rec.validator.validator import DirectoryValidator, FileValidator, para_c
 from mx_rec.util.global_env_conf import global_env
 from mx_rec.util.log import logger
 from mx_rec.optimizers.base import CustomizedOptimizer
+from mx_rec.util.tf_version_adapter import npu_ops
+
+SAVE_SPARSE_PATH_PREFIX = "sparse"


 # define save model thread
@@ -50,36 +53,38 @@ class SaveModelThread(threading.Thread):


 class Saver(object):
-    @staticmethod
-    def _make_table_name_dir(root_dir, table_instance, table_name):
-        if not table_instance.is_hbm:
-            table_dir = os.path.join(root_dir, "HashTable", "DDR", table_name)
-        else:
-            table_dir = os.path.join(root_dir, "HashTable", "HBM", table_name)
-        try:
-            tf.io.gfile.makedirs(table_dir)
-        except Exception as err:
-            raise RuntimeError(f"make dir {table_dir} for saving sparse table failed!") from err
-
     @para_checker_decorator(check_option_list=[
         ("var_list", ClassValidator, {"classes": (list, type(None))}),
         ("max_to_keep", IntValidator, {"min_value": 0, "max_value": MAX_INT32}, ["check_value"]),
         ("prefix_name", ClassValidator, {"classes": (str, type(None))}),
         ("prefix_name", OptionalStringValidator, {"min_len": 1, "max_len": 50}, ["check_string_length"]),
     ])
-
def __init__(self, var_list=None, max_to_keep=3, prefix_name="checkpoint"): + def __init__(self, var_list=None, max_to_keep=3, prefix_name="checkpoint", warm_start_tables=None): self.max_to_keep = max_to_keep self._prefix_name = prefix_name self.var_list = var_list self.rank_id = get_rank_id() self.local_rank_size = get_local_rank_size() self.local_rank_id = self.rank_id % self.local_rank_size + self.rank_size = get_rank_size() self.save_op_dict = defaultdict(dict) - self.restore_fetch_list = [] + self.restore_fetch_dict = defaultdict() self.placeholder_dict = defaultdict(dict) self._last_checkponts = [] self.config_instance = ConfigInitializer.get_instance() self.build() + self.warm_start_tables = warm_start_tables + + @staticmethod + def _make_table_name_dir(root_dir, table_instance, table_name): + if not table_instance.is_hbm: + table_dir = os.path.join(root_dir, "HashTable", "DDR", table_name) + else: + table_dir = os.path.join(root_dir, "HashTable", "HBM", table_name) + try: + tf.io.gfile.makedirs(table_dir) + except Exception as err: + raise RuntimeError(f"make dir {table_dir} for saving sparse table failed!") from err def build(self): if self.var_list is None: @@ -125,9 +130,9 @@ class Saver(object): if global_step: if not isinstance(global_step, compat.integral_types): global_step = int(sess.run(global_step)) - ckpt_name = f"sparse-{base_name}-{global_step}" + ckpt_name = f"{SAVE_SPARSE_PATH_PREFIX}-{base_name}-{global_step}" else: - ckpt_name = f"sparse-{base_name}" + ckpt_name = f"{SAVE_SPARSE_PATH_PREFIX}-{base_name}" saving_path = os.path.join(directory, ckpt_name) self.config_instance.train_params_config.sparse_dir = saving_path @@ -165,7 +170,7 @@ class Saver(object): comm = MPI.COMM_WORLD rank = comm.Get_rank() comm.Barrier() - if rank == 0: + if should_write_data(rank, saving_path): table_list = self.save_op_dict.keys() for table_name in table_list: self.merge_sparse_file(saving_path, table_name) @@ -175,21 +180,20 @@ class Saver(object): logger.info("======== Saving finished for rank id %s ========", self.rank_id) @performance("Restore") - def restore(self, sess, reading_path): + def restore(self, sess, reading_path, warm_start_tables=None): logger.debug("======== Start restoring ========") if not check_file_system_is_valid(reading_path): raise ValueError("the path to save sparse embedding table data belong to invalid file system, " "only local file system and hdfs file system supported. 
") directory, base_name = os.path.split(reading_path) - ckpt_name = f"sparse-{base_name}" + ckpt_name = f"{SAVE_SPARSE_PATH_PREFIX}-{base_name}" reading_path = os.path.join(directory, ckpt_name) - self.config_instance.train_params_config.sparse_dir = reading_path if not tf.io.gfile.exists(reading_path): raise FileExistsError(f"Given dir {reading_path} does not exist, please double check.") - self._restore(sess, reading_path) + self._restore(sess, reading_path, warm_start_tables) logger.info("sparse model was restored from dir '%s' .", reading_path) logger.debug("======== Restoring finished ========") @@ -233,7 +237,21 @@ class Saver(object): attribute = attribute.astype(np.int64) attribute_dir = os.path.join(upper_dir, "slice.attribute") - attribute.tofile(attribute_dir) + with tf.io.gfile.GFile(attribute_dir, "wb") as file: + attribute = attribute.tostring() + file.write(attribute) + + def get_warm_start_dict(self, table_list): + placeholder_dict = defaultdict(dict) + restore_fetch_list = [] + for table_name, v in self.placeholder_dict.items(): + if table_name in table_list: + placeholder_dict[table_name] = v + restore_fetch_list.append(self.restore_fetch_dict.get(table_name)) + + if not restore_fetch_list: + logger.warning("no tables can be warm start restored.") + return placeholder_dict, restore_fetch_list @performance("_save") def _save(self, sess, root_dir): @@ -242,10 +260,15 @@ class Saver(object): if optimizer_instance: set_optimizer_info(optimizer_instance, table_name) - if self.config_instance.hybrid_manager_config.asc_manager: - self.config_instance.hybrid_manager_config.save_host_data(root_dir) - logger.debug(f"host data was saved.") + table_instance0 = self.config_instance.sparse_embed_config.get_table_instance(self.var_list[0]) + if table_instance0.is_hbm: + self._save_hbm(sess, root_dir) + else: + self._save_ddr(sess, root_dir) + logger.debug(f"Host data was saved.") + def _save_hbm(self, sess, root_dir): + self.config_instance.hybrid_manager_config.save_host_data(root_dir) if self.config_instance.use_dynamic_expansion: # Data related to dynamic expansion needs to be saved only on the host side. return @@ -262,6 +285,42 @@ class Saver(object): for thread in threads: thread.join() + def _save_ddr(self, sess, root_dir): + # 接受host侧传来的需要swap_out的offset用于更新host侧并保存 + self.config_instance.hybrid_manager_config.fetch_device_emb() + # In DDR mode, within the save process, the graph has been fixed and cannot execute the get_next op. + # The _unsafe_unfinalize operation can modify the state of the graph being fixed. 
+ sess.graph._unsafe_unfinalize() + for var in self.var_list: + table_instance = self.config_instance.sparse_embed_config.get_table_instance(var) + table_name = table_instance.table_name + + use_static = ConfigInitializer.get_instance().use_static + max_lookup_vec_size = None + if use_static: + max_lookup_vec_size = table_instance.send_count * self.rank_size + swap_out_pos, swap_out_len = npu_ops.gen_npu_ops.get_next( + output_types=[tf.int32, tf.int32], + output_shapes=[[max_lookup_vec_size], []], + channel_name=f'{table_name}_save_h2d_{TRAIN_CHANNEL_ID}') + if use_static: + swap_out_pos = swap_out_pos[:swap_out_len] + + table = [var] + optimizer = ConfigInitializer.get_instance().optimizer_config.get_optimizer_by_table_name(table_name) + if optimizer is not None: + for slots in optimizer.values(): + table += list(slots.values()) + + swap_outs = [tf.gather(one_table, swap_out_pos) for one_table in table] + swap_out = tf.concat(swap_outs, axis=1) + channel_name = f'{table_name}_save_d2h_{TRAIN_CHANNEL_ID}' + logger.debug('channel %s was built for op swap_out_op.', channel_name) + swap_out_op = npu_ops.outfeed_enqueue_op(channel_name=channel_name, inputs=[swap_out]) + # 发送host需要的embedding + sess.run(swap_out_op) + self.config_instance.hybrid_manager_config.save_host_data(root_dir) + def _get_valid_dict_data(self, dump_data_dict, table_name): host_data = self.config_instance.hybrid_manager_config.get_host_data(table_name) offset = list(host_data) @@ -294,7 +353,7 @@ class Saver(object): table_instance.emb_size], name=DataName.EMBEDDING.value) assign_op = var.assign(variable) - self.restore_fetch_list.append(assign_op) + self.restore_fetch_dict[table_instance.table_name] = [assign_op] optimizer = ConfigInitializer.get_instance().optimizer_config.get_optimizer_by_table_name( table_instance.table_name) if optimizer: @@ -313,25 +372,35 @@ class Saver(object): if sub_optimizer_placeholder_dict.get(key_state).graph is not state.graph: continue assign_op = state.assign(sub_optimizer_placeholder_dict.get(key_state)) - self.restore_fetch_list.append(assign_op) + self.restore_fetch_dict[table_instance.table_name].append(assign_op) - def _restore(self, sess, reading_path): - for table_name in self.placeholder_dict: + def _restore(self, sess, reading_path, warm_start_tables=None): + # 根据table_list去改造 + if warm_start_tables: + placeholder_dict, restore_fetch_list = self.get_warm_start_dict(warm_start_tables) + else: + placeholder_dict, restore_fetch_list = self.placeholder_dict, self.restore_fetch_dict + + for table_name in placeholder_dict: optimizer_instance = ConfigInitializer.get_instance().optimizer_config.optimizer_instance if optimizer_instance: set_optimizer_info(optimizer_instance, table_name) if self.config_instance.hybrid_manager_config.asc_manager: - self.config_instance.hybrid_manager_config.restore_host_data(reading_path) + self.config_instance.hybrid_manager_config.restore_host_data(reading_path, warm_start_tables) logger.info("host data was restored.") + table_instance0 = self.config_instance.sparse_embed_config.get_table_instance(self.var_list[0]) + if not table_instance0.is_hbm: + return + if self.config_instance.use_dynamic_expansion: # Data related to dynamic expansion needs to be restored only on the host side. 
return restore_feed_dict = defaultdict(dict) - for table_name, sub_placeholder_dict in self.placeholder_dict.items(): + for table_name, sub_placeholder_dict in placeholder_dict.items(): load_offset = self.config_instance.hybrid_manager_config.get_load_offset(table_name) fill_placeholder(reading_path, sub_placeholder_dict, restore_feed_dict, NameDescriptor(table_name, DataName.EMBEDDING.value), load_offset) @@ -341,7 +410,7 @@ class Saver(object): _fill_placeholder_for_optimizer(optimizer_state_placeholder_dict_group, reading_path, restore_feed_dict, table_name, load_offset) - sess.run(self.restore_fetch_list, feed_dict=restore_feed_dict) + sess.run(restore_fetch_list, feed_dict=restore_feed_dict) class NameDescriptor: @@ -393,7 +462,7 @@ def save_embedding_data(root_dir, table_name, dump_data_dict, suffix): attribute = dict() attribute[DataAttr.DATATYPE.value] = data_to_write.dtype.name attribute[DataAttr.SHAPE.value] = data_to_write.shape - write_binary_data(target_path, suffix, data_to_write, attributes=attribute) + write_binary_data(target_path, suffix, data_to_write) def save_feature_mapping_data(root_dir, table_name, dump_data_dict, suffix): @@ -405,7 +474,7 @@ def save_feature_mapping_data(root_dir, table_name, dump_data_dict, suffix): attribute = dict() attribute[DataAttr.DATATYPE.value] = data_to_write.dtype.name attribute[DataName.THRESHOLD.value] = int(dump_data_dict.get(DataName.THRESHOLD.value)) - write_binary_data(target_path, suffix, data_to_write, attributes=attribute) + write_binary_data(target_path, suffix, data_to_write) def save_offset_data(root_dir, table_name, dump_data_dict, suffix): @@ -416,7 +485,7 @@ def save_offset_data(root_dir, table_name, dump_data_dict, suffix): attribute = dict() attribute[DataAttr.DATATYPE.value] = data_to_write.dtype.name - write_binary_data(target_path, suffix, data_to_write, attributes=attribute) + write_binary_data(target_path, suffix, data_to_write) def save_optimizer_state_data(root_dir, table_name, optimizer_name, dump_optimizer_data, suffix): @@ -427,7 +496,7 @@ def save_optimizer_state_data(root_dir, table_name, optimizer_name, dump_optimiz attribute = dict() attribute[DataAttr.DATATYPE.value] = data_to_write.dtype.name attribute[DataAttr.SHAPE.value] = data_to_write.shape - write_binary_data(target_path, suffix, data_to_write, attributes=attribute) + write_binary_data(target_path, suffix, data_to_write) def generate_path(*args): @@ -438,15 +507,16 @@ def generate_file_name(suffix): return "slice_%d.data" % suffix, "slice_%d.attribute" % suffix -def write_binary_data(writing_path, suffix, data, attributes=None): +def write_binary_data(writing_path: str, suffix: int, data: np.ndarray): try: tf.io.gfile.makedirs(writing_path) except Exception as err: raise RuntimeError(f"make dir {writing_path} for writing data failed!") from err data_file, attribute_file = generate_file_name(suffix) target_data_dir = os.path.join(writing_path, data_file) - - with tf.io.gfile.GFile(target_data_dir, "ab") as file: + # append mode of hdfs system supports not well when the file not exists. 
+ file_mode = "wb" if not tf.io.gfile.exists(target_data_dir) else "ab" + with tf.io.gfile.GFile(target_data_dir, file_mode) as file: data = data.tostring() file.write(data) @@ -470,7 +540,11 @@ def read_binary_data(reading_path: str, data_name: str, table_name: str, load_of with tf.io.gfile.GFile(target_attribute_dir, "rb") as fin: validate_read_file(target_attribute_dir) - attributes = np.fromfile(target_attribute_dir, dtype=np.int64) + attributes = fin.read() + try: + attributes = np.fromstring(attributes, dtype=np.int64) + except ValueError as err: + raise RuntimeError(f"get attributes from file {target_attribute_dir} failed.") from err with tf.io.gfile.GFile(target_data_dir, "rb") as file: validate_read_file(target_data_dir) @@ -622,4 +696,15 @@ def set_optimizer_info(optimizer: CustomizedOptimizer, table_name: str): """ from mxrec_pybind import OptimizerInfo optim_info = OptimizerInfo(optimizer.optimizer_type, optimizer.optim_param_list) - ConfigInitializer.get_instance().hybrid_manager_config.set_optim_info(table_name, optim_info) \ No newline at end of file + ConfigInitializer.get_instance().hybrid_manager_config.set_optim_info(table_name, optim_info) + + +def should_write_data(rank_id: int, save_path: str) -> bool: + # When using hdfs filesystem, only the rank0 process execute write data operation, assuming use same hdfs path in + # multi-machine. + # When using local filesystem, the process which `rank_id % local_rank_size == 0` execute write data operation. + # When using hdfs filesystem, and use different hdfs path to save data, should modify check condition + # as same as local filesystem. + is_hdfs = check_file_system_is_hdfs(save_path) + local_rank_size = get_local_rank_size() + return rank_id == 0 if is_hdfs else rank_id % local_rank_size == 0 diff --git a/mx_rec/saver/warm_start.py b/mx_rec/saver/warm_start.py new file mode 100644 index 0000000000000000000000000000000000000000..7ceb14c1e792c5a25896b9a07b26d98d15816584 --- /dev/null +++ b/mx_rec/saver/warm_start.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import os +import logging +import re +from typing import List +import six + +import tensorflow as tf +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.training import warm_starting_util + +from mx_rec.util.log import logger +from mx_rec.saver.saver import Saver + + +class WarmStartController: + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(WarmStartController, cls).__new__(cls) + cls._instance._warm_start_dict = {} + cls._instance.table_name_to_prev_table_name = {} + return cls._instance + + def __init__(self): + logging.info("start to build WarmStartController.") + + def add_element(self, path: str, table_list: List[str]): + if path not in self._warm_start_dict: + self._warm_start_dict[path] = table_list + else: + self._warm_start_dict[path] += table_list + + def add_table_to_prev_table(self, table: str, prev_table: str): + self.table_name_to_prev_table_name[table] = prev_table + + def get_elements(self): + return self._warm_start_dict + + +def patch_for_warm_start(): + estimator_lib.Estimator.__init__ = patch_estimator_init(estimator_lib.Estimator.__init__) + warm_starting_util.warm_start = patch_for_func_warm_start(warm_starting_util.warm_start) + estimator_lib.Estimator.train = patch_for_estimator_train(estimator_lib.Estimator.train) + + +def patch_estimator_init(func): + def wrapper(*args, **kwargs): + warm_start_from = kwargs.get('warm_start_from', None) + if warm_start_from: + kwargs['warm_start_from'] = warm_settings_filter(warm_start_from) + return func(*args, **kwargs) + return wrapper + + +def patch_for_func_warm_start(func): + def wrapper(*args, **kwargs): + ckpt_to_initialize_from = args[0] + if isinstance(ckpt_to_initialize_from, (list, tuple)): + vars_to_warm_start_list = args[1] + var_name_to_prev_var_name_list = args[3] + warm_start_num = len(ckpt_to_initialize_from) + for i in range(warm_start_num): + f = func(ckpt_to_initialize_from[i], vars_to_warm_start_list[i], args[2], + var_name_to_prev_var_name_list[i], **kwargs) + return f + else: + return func(*args, **kwargs) + return wrapper + + +def patch_for_estimator_train(func): + def wrapper(*args, **kwargs): + hooks = kwargs.get('hooks', []) + if WarmStartController().get_elements(): + hooks.append(SparseRestoreHook()) + return func(*args, **kwargs) + return wrapper + + +def warm_settings_filter(warm_start_from): + warm_start_from_res = None + if isinstance(warm_start_from, estimator_lib.WarmStartSettings): + if isinstance(warm_start_from.ckpt_to_initialize_from, (list, tuple)): + out_setting_list = [] + logger.info("According to warm_start_settings, warm start will load from more than one checkpoint path.") + warm_start_settings_list = _build_warm_settings_list(warm_start_from) + for setting in warm_start_settings_list: + filter_setting = _warm_settings_filter(setting) + if filter_setting: + out_setting_list.append(filter_setting) + if out_setting_list: + warm_start_from_res = recover_warm_settings(out_setting_list) + elif isinstance(warm_start_from.ckpt_to_initialize_from, (six.string_types, six.binary_type)): + logger.info("According to warm_start_settings, warm start will load from only one checkpoint path.") + filter_setting = _warm_settings_filter(warm_start_from) + if filter_setting: + warm_start_from_res = filter_setting + elif isinstance(warm_start_from, (six.string_types, six.binary_type)): + table_name_list = 
get_table_name_set_by_ckpt_path(warm_start_from) + WarmStartController().add_element(warm_start_from, table_name_list) + warm_start_from_res = warm_start_from + else: + raise ValueError("Invalid parameter: warm_start_from. ") + return warm_start_from_res + + +def recover_warm_settings(setting_list: List[tf.estimator.WarmStartSettings]) -> tf.estimator.WarmStartSettings: + """ + Recover WarmStartSettings from a list of custom-defined WarmStartSettings. + """ + ckpt_to_initialize_from_list = [] + vars_to_warm_start_list = [] + var_name_to_prev_var_name_list = [] + for setting in setting_list: + ckpt_to_initialize_from_list.append(setting.ckpt_to_initialize_from) + vars_to_warm_start_list.append(setting.vars_to_warm_start) + var_name_to_prev_var_name_list.append(setting.var_name_to_prev_var_name) + + return estimator_lib.WarmStartSettings( + ckpt_to_initialize_from=ckpt_to_initialize_from_list, + vars_to_warm_start=vars_to_warm_start_list, + var_name_to_prev_var_name=var_name_to_prev_var_name_list) + + +def _build_warm_settings_list(warm_start_from: tf.estimator.WarmStartSettings) -> List[tf.estimator.WarmStartSettings]: + """ + Converts custom-defined WarmStartSettings into a list of TensorFlow-native WarmStartSettings. + """ + ckpt_to_initialize_from = warm_start_from.ckpt_to_initialize_from + vars_to_warm_start = warm_start_from.vars_to_warm_start + var_name_to_prev_var_name = warm_start_from.var_name_to_prev_var_name + # 类型校验 + for params in [vars_to_warm_start, var_name_to_prev_var_name]: + if not isinstance(params, (list, tuple)): + raise ValueError("If you choose to load from multiple model paths through the warm start option, " + "then the parameter type in the warm settings should be a list.") + # 长度校验 + if not (len(ckpt_to_initialize_from) == len(vars_to_warm_start) == len(var_name_to_prev_var_name)): + raise ValueError("If you choose to load from multiple model paths through the warm start option, " + "then the parameter list list should be the same length. ") + warm_start_settings_count = len(ckpt_to_initialize_from) + + warm_start_settings_list = [] + for i in range(warm_start_settings_count): + tmp_settings = estimator_lib.WarmStartSettings( + ckpt_to_initialize_from=ckpt_to_initialize_from[i], + vars_to_warm_start=vars_to_warm_start[i], + var_name_to_prev_var_name=var_name_to_prev_var_name[i]) + warm_start_settings_list.append(tmp_settings) + return warm_start_settings_list + + +def _warm_settings_filter(warm_start_setting: tf.estimator.WarmStartSettings) -> tf.estimator.WarmStartSettings: + """ + Filter the vars_to_warm_start parameter to remove sparse table parameters. 
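For reference, a multi-checkpoint setting of the shape `_build_warm_settings_list` expects looks like the sketch below; the paths and regexes are illustrative, and since `WarmStartSettings` is a plain namedtuple the parallel lists pass through unchecked until they are validated here:

```python
import tensorflow as tf

# Illustrative multi-path warm-start setting: three parallel lists with one
# entry per checkpoint path (paths and patterns are made up for the example).
multi_path_settings = tf.estimator.WarmStartSettings(
    ckpt_to_initialize_from=["/ckpt/run_a", "/ckpt/run_b"],
    vars_to_warm_start=[".*dense.*", ".*dnn.*"],
    var_name_to_prev_var_name=[{}, {}])
```

With the patches installed, passing such a setting as `warm_start_from` routes it through `warm_settings_filter`, which splits it into one native setting per path and strips the sparse-table entries out of each.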
+ """ + vars_to_warm_start = warm_start_setting.vars_to_warm_start + var_name_to_prev_var_name = warm_start_setting.var_name_to_prev_var_name + vars_to_warm_start_res = [] + warm_start_setting_res = None + table_name_list = get_table_name_set_by_ckpt_path(warm_start_setting.ckpt_to_initialize_from) + if isinstance(vars_to_warm_start, str): + matching_tables = [table for table in table_name_list if re.match(vars_to_warm_start, table)] + if matching_tables: + WarmStartController().add_element(warm_start_setting.ckpt_to_initialize_from, matching_tables) + warm_start_setting_res = warm_start_setting + elif all(isinstance(v, str) for v in vars_to_warm_start): + sparse_vars = [] + for v in vars_to_warm_start: + matching_tables = [table for table in table_name_list if re.match(v, table)] + if matching_tables: + sparse_vars.append(v) + WarmStartController().add_element(warm_start_setting.ckpt_to_initialize_from, matching_tables) + vars_to_warm_start_res = [v for v in vars_to_warm_start if v not in sparse_vars] + if vars_to_warm_start_res: + warm_start_setting_res = estimator_lib.WarmStartSettings( + ckpt_to_initialize_from=warm_start_setting.ckpt_to_initialize_from, + vars_to_warm_start=vars_to_warm_start_res, + var_name_to_prev_var_name=warm_start_setting.var_name_to_prev_var_name) + else: + raise ValueError("vars_to_warm_start must be list or str!") + return warm_start_setting_res + + +def get_table_name_set_by_ckpt_path(warm_start_path: str) -> List[str]: + ''' + Get the list of sparse table names saved under the path 'warm_start_path'. + ''' + table_name_list = [] + if tf.io.gfile.isdir(warm_start_path): + restore_path = get_latest_ckpt(warm_start_path) + else: + restore_path = warm_start_path + directory, base_name = os.path.split(restore_path) + ckpt_name = f"sparse-{base_name}" + sparse_path = os.path.join(directory, ckpt_name) + if not tf.io.gfile.isdir(sparse_path): + logger.info("under the warm start path %s, sparse directory %s not exists.", warm_start_path, sparse_path) + else: + for dirname in tf.io.gfile.listdir(sparse_path): + table_name_list.append(dirname) + return table_name_list + + +def get_latest_ckpt(warm_start_path: str) -> str: + ckpt_path = os.path.join(warm_start_path, "checkpoint") + if not tf.io.gfile.exists(ckpt_path): + raise FileNotFoundError(f"Checkpoint file is missing under the warm start model path {warm_start_path}") + with tf.io.gfile.GFile(ckpt_path, "r") as f: + latest_ckpt = f.readline().rstrip() + latest_ckpt = latest_ckpt.split(":")[1].strip(' ').replace('"', '') + latest_ckpt = latest_ckpt.split("/")[-1] + path = os.path.join(warm_start_path, latest_ckpt) + return path + + +class SparseRestoreHook(tf.estimator.SessionRunHook): + def __init__(self): + logging.info("In warm start mode, SparseRestoreHook has been initialized.") + self._is_warm_start = False + self._saver = None + self._warm_start_dict = {} + + def begin(self): + self._saver = Saver() + logging.info("In warm start mode, begin SparseRestoreHook.") + + def after_create_session(self, session, coord): + if not self._is_warm_start: + self._warm_start_dict = WarmStartController().get_elements() + for path, restore_tables in self._warm_start_dict.items(): + restore_path = get_latest_ckpt(path) + self._saver.restore(session, restore_path, restore_tables) + self._is_warm_start = True diff --git a/mx_rec/util/communication/hccl_mgmt.py b/mx_rec/util/communication/hccl_mgmt.py index 6eb5a70f83b9920d917a09c6f8cc981747efdd52..43042d6b5662d86fb4b53b9e7672114a34bed9ec 100644 --- 
a/mx_rec/util/communication/hccl_mgmt.py +++ b/mx_rec/util/communication/hccl_mgmt.py @@ -16,19 +16,20 @@ # ============================================================================== import json -import os -import re +from typing import Dict, List -from mx_rec.constants.constants import VALID_DEVICE_ID_LIST, MIN_SIZE, MAX_CONFIG_SIZE, MAX_DEVICE_ID, \ - MIN_RANK_SIZE, MAX_RANK_SIZE -from mx_rec.validator.validator import FileValidator, para_checker_decorator, StringValidator, \ - Convert2intValidator +from mx_rec.constants.constants import MIN_SIZE, MAX_CONFIG_SIZE, MAX_DEVICE_ID +from mx_rec.validator.validator import FileValidator from mx_rec.util.global_env_conf import global_env -def parse_hccl_json(): +def parse_hccl_json() -> Dict[int, int]: + """ + Used for rank table file configured training situation. + :return: rank_id to logic_id mapping dictionary. + """ rank_table_path = global_env.rank_table_file - with open(rank_table_path, "r", encoding="utf-8"): + with open(rank_table_path, "r", encoding="utf-8") as file: # check whether json file is valid file_validator = FileValidator("RANK_TABLE_FILE", rank_table_path) # 1.check whether rank_table_path is soft link @@ -37,14 +38,13 @@ def parse_hccl_json(): file_validator.check_file_size(MAX_CONFIG_SIZE, MIN_SIZE) file_validator.check() - rank_table_path = os.path.realpath(global_env.rank_table_file) - with open(rank_table_path, "r", encoding="utf-8") as file: try: table_hccl = json.load(file) except FileNotFoundError as e: raise ValueError("rank table file not found") from e except json.JSONDecodeError as e: raise ValueError("rank table file is unable to parse as json") from e + if "server_list" not in table_hccl: raise AttributeError(f"Lack of attribute server_list.") if not table_hccl.get("server_list"): @@ -62,76 +62,51 @@ def parse_hccl_json(): if "rank_id" not in device or not device.get("rank_id").isdigit(): raise ValueError(f"hccl_json rank_id wrong.") rank_id = int(device.get("rank_id")) + if "device_id" not in device or not device.get("device_id").isdigit(): raise ValueError(f"hccl_json device_id wrong.") import mxrec_pybind - res = mxrec_pybind.get_logic_id(int(device.get("device_id"))) - if res < 0: - raise RuntimeError( - f"get logic id from physic id fail, error code is {res}, please check if dsmi api is functional.") - if res > MAX_DEVICE_ID: + logic_id = mxrec_pybind.get_logic_id(int(device.get("device_id"))) + if logic_id < 0: + raise RuntimeError(f"get logic id from physic id fail, error code is {logic_id}, " + f"please check if dsmi api is functional.") + if logic_id > MAX_DEVICE_ID: raise ValueError(f"get logic id from physic id fail, the device id is invalid.") - rank_to_device_dict[rank_id] = res - + rank_to_device_dict[rank_id] = logic_id return rank_to_device_dict -def set_hccl_info_without_json() -> dict: +def set_hccl_info_without_json() -> Dict[int, int]: """ Used for no rank table file configured training situation. - :return: device_id and logic_id mapping. + :return: rank_id to logic_id mapping dictionary. 
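For reference, a minimal rank-table document with just the fields `parse_hccl_json` reads; the values are illustrative, and the real parser additionally validates the file and maps each physical `device_id` to a logic id through `mxrec_pybind.get_logic_id`:

```python
import json

# Minimal rank-table shape; real files carry more fields, and real parsing
# converts device_id (physical) to a logic id via mxrec_pybind.
rank_table = json.loads("""
{
  "server_list": [
    {"device": [
      {"device_id": "0", "rank_id": "0"},
      {"device_id": "1", "rank_id": "1"}
    ]}
  ]
}
""")
for server in rank_table["server_list"]:
    for device in server["device"]:
        print(int(device["rank_id"]), "->", int(device["device_id"]))
```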
""" - visible_devices = global_env.ascend_visible_devices - rank_size = global_env.cm_worker_size - chief_device = global_env.cm_chief_device - device_list = get_device_list(visible_devices) - chief_device = int(chief_device) - rank_size = int(rank_size) - - sorted_device_list = sorted(device_list) + env_rank_size = global_env.cm_worker_size + env_chief_device = global_env.cm_chief_device + device_list = get_device_list() + chief_device = int(env_chief_device) + rank_size = int(env_rank_size) - if chief_device not in sorted_device_list: + if chief_device not in device_list: raise ValueError(f"The environment variable CM_CHIEF_DEVICE {chief_device} is not in the local device list. ") rank_to_device_dict = {} - chief_index = sorted_device_list.index(chief_device) - sorted_device_list = sorted_device_list[chief_index:] + sorted_device_list[0: chief_index] - sorted_device_list = sorted_device_list[:rank_size] - - for device_idx in sorted_device_list: - import mxrec_pybind - res = mxrec_pybind.get_logic_id(int(device_idx)) - if res < 0: - raise RuntimeError( - f"get logic id from physic id fail, error code is {res}, please check if dsmi api is functional.") - - if res > MAX_DEVICE_ID: - raise ValueError(f"get logic id from physic id fail. res: {res}, chief_device: {chief_device}, " - f"device_idx: {device_idx}") - index = sorted_device_list.index(device_idx) - rank_to_device_dict[index] = res + chief_index = device_list.index(chief_device) + device_list = device_list[chief_index:] + device_list[:chief_index] + device_list = device_list[:rank_size] + + for rank_id, device_id in enumerate(device_list): + rank_to_device_dict[rank_id] = device_id return rank_to_device_dict -def get_device_list(ascend_visible_devices): - device_list = [] - try: - nums = re.findall(r'\d+', ascend_visible_devices) - # eg1:4-11, 则nums=['4', '11'] eg2:0-3,8-11 则nums['0', '3', '8', '11'] - if not all(int(i) <= MAX_DEVICE_ID for i in nums): - raise ValueError("invalid env variable ascend_visible_devices.") - ranges = re.findall(r'\d+-\d+', ascend_visible_devices) - # eg1:4-11, 则ranges=['4-11'] eg2:0-3,8-11 则ranges['0-3', '8-11'] - for r in ranges: - start, end = map(int, r.split('-')) # '4-11', 则start 4, end 11. ['0-3', '8-11'] - if start >= end: - raise ValueError("invalid env variable ascend_visible_devices.") - nums.extend(range(start, end + 1)) - device_list = sorted(list(set(map(int, nums)))) - except ValueError as error: - raise ValueError("Invalid env variable ascend_visible_devices, no valid device id is configured.") from error - - if not device_list: - raise ValueError("No device is available in the environment.") +def get_device_list() -> List[int]: + """ + Obtain the number of visible Ascend devices in the environment. + :return: the logic id list of visible Ascend devices . 
+ """ + import mxrec_pybind + device_count = mxrec_pybind.get_device_count() + device_list = [i for i in range(device_count)] return device_list \ No newline at end of file diff --git a/mx_rec/util/communication/hccl_ops.py b/mx_rec/util/communication/hccl_ops.py index 52fbf74c2367d9de713a25895d09b2c0de16a672..d4ea6136125739da70a9f24961210a5a1ac5d95b 100644 --- a/mx_rec/util/communication/hccl_ops.py +++ b/mx_rec/util/communication/hccl_ops.py @@ -29,9 +29,9 @@ def get_rank_id() -> Optional[int]: def get_device_id() -> Optional[int]: """ - Get the device id of the calling process + Get the device logic id of the calling process Note: this method should be used after mpi init - :return: int or None, the device id of the calling process + :return: int or None, the device logic id of the calling process """ if global_env.rank_table_file: rank_to_device_dict = parse_hccl_json() diff --git a/mx_rec/util/config_utils/embedding_utils.py b/mx_rec/util/config_utils/embedding_utils.py index 68ceef3a4545e04f6e404c09ed1fd68c6af833f5..e13d9d511f313fdb075ea75b54eff314a367de7a 100644 --- a/mx_rec/util/config_utils/embedding_utils.py +++ b/mx_rec/util/config_utils/embedding_utils.py @@ -3,6 +3,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. from typing import Optional +from tensorflow.python.framework import ops from tensorflow import Variable from mx_rec.util.log import logger @@ -18,6 +19,7 @@ class SparseEmbedConfig: self._table_name_set = set() self._removing_var_list = [] self._name_to_var_dict = dict() + self._tensor_to_table_instance_dict = dict() @property def table_instance_dict(self): @@ -45,6 +47,12 @@ class SparseEmbedConfig: return self._table_instance_dict.get(key) + def get_table_instance_by_tensor(self, tensor) -> object: + if tensor not in self._tensor_to_table_instance_dict: + raise KeyError(f"Given tensor does not exist.") + + return self._tensor_to_table_instance_dict.get(tensor) + def get_table_instance_by_name(self, table_name: Optional[str]) -> object: if table_name not in self._name_to_var_dict: raise KeyError(f"Given table name does not exist.") @@ -74,5 +82,11 @@ class SparseEmbedConfig: self._name_to_var_dict[name] = key self._table_instance_dict[key] = instance + def insert_table_instance_to_tensor_dict(self, tensor: ops.Tensor, instance: object) -> None: + if tensor in self._tensor_to_table_instance_dict: + raise KeyError(f"Given tensor {tensor} has been used.") + logger.debug("Record one hash table for expansion mode, with tensor: %s.", tensor) + self._tensor_to_table_instance_dict[tensor] = instance + def export_table_num(self) -> int: return len(self.table_instance_dict) if self.table_instance_dict else 0 diff --git a/mx_rec/util/config_utils/feature_spec_utils.py b/mx_rec/util/config_utils/feature_spec_utils.py index 4c40996c9d39fd2ca68fa02440eb2b816d8b5672..f244bb39ae2d010a629df820abaafb69b84cc24b 100644 --- a/mx_rec/util/config_utils/feature_spec_utils.py +++ b/mx_rec/util/config_utils/feature_spec_utils.py @@ -25,7 +25,7 @@ class FeatureSpecConfig: def clear_same_table_feature_spec(self, table_name: Optional[str], is_training: bool) -> None: if self.table_name_to_feature_spec.get(table_name) is None or \ self.table_name_to_feature_spec.get(table_name).get(is_training) is None: - raise KeyError("The table name `%s` does not exist in table_name_to_feature_spec, " + raise KeyError("the table name `%s` does not exist in table_name_to_feature_spec, " "please check whether the insert_feature_spec(...) 
is invoked.", table_name)
         self.table_name_to_feature_spec.get(table_name)[is_training] = []
         logger.debug("The feature spec of the table name `%s` has been cleared.", table_name)
diff --git a/mx_rec/util/config_utils/hybrid_mgmt_utils.py b/mx_rec/util/config_utils/hybrid_mgmt_utils.py
index 737ce7cb2e8a5e78e440a17c30f03b50d2e70b2b..26624461e30eb13f40f83293d25b5026eeae2134 100644
--- a/mx_rec/util/config_utils/hybrid_mgmt_utils.py
+++ b/mx_rec/util/config_utils/hybrid_mgmt_utils.py
@@ -83,11 +83,18 @@ class HybridManagerConfig:
             self.asc_manager.save(root_dir)
         logger.debug("Data from host pipeline has been saved.")

-    def restore_host_data(self, root_dir: Optional[str]) -> None:
+    def restore_host_data(self, root_dir: Optional[str], warm_start_tables=None) -> None:
         if self.asc_manager is None:
             raise RuntimeError("ASC manager does not exist.")
-
-        if not self.asc_manager.load(root_dir):
+        if not warm_start_tables:
+            warm_start_tables = []
+        if not self.asc_manager.load(root_dir, warm_start_tables):
             raise TypeError("Asc load data does not match usr setups, \
                             please re-consider if you want to restore from this dir")
         logger.debug("Data from host pipeline has been restored.")
+
+    def fetch_device_emb(self):
+        if self.asc_manager is None:
+            raise RuntimeError("ASC manager does not exist.")
+        self.asc_manager.fetch_device_emb()
+        logger.debug("request to fetch embeddings from device to host for saving has been sent")
diff --git a/mx_rec/util/cpu.py b/mx_rec/util/cpu.py
index f4d299ed4cfe385bfc2411fac83c0412e00a7813..a7848d7f5bca124e2aac8ef65634ab674fe30155 100644
--- a/mx_rec/util/cpu.py
+++ b/mx_rec/util/cpu.py
@@ -3,7 +3,6 @@
 # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.

 import ctypes
-from ctypes import *
 import psutil

 from mx_rec.util.log import logger
@@ -26,7 +25,7 @@ class PcieInfo(ctypes.Structure):
     ]


-def get_card_and_deivce(logic_id):
+def get_card_and_device(logic_id):
     """
     通过芯片逻辑id获取芯片的卡id和device id
     一张卡可能有多个芯片,对应多个device_id,但每个芯片的逻辑ID
@@ -52,7 +51,7 @@ def get_pcie_id(card_id, device_id):
     dev = ctypes.c_int(device_id)
     ret = g_dcmi.dcmi_get_device_pcie_info_v2(card, dev, ctypes.pointer(info))
     if ret != 0:
-        raise OSError("cant get pcie info of device {card_id}:{deivce_id}")
+        raise OSError(f"can't get pcie info of device {card_id}:{device_id}")
     pcie_id = f'{info.domain:04X}:{info.bdf_busid:02x}:'
     pcie_id += f'{info.bdf_deviceid:02x}.{info.bdf_funcid}'
     return pcie_id
@@ -87,7 +86,7 @@ def bind_cpu_by_device_logic_id(logic_id):
         logger.error(e)
         return False
     try:
-        card_id, device_id = get_card_and_deivce(logic_id)
+        card_id, device_id = get_card_and_device(logic_id)
         pcie_id = get_pcie_id(card_id, device_id)
         numa = get_numa_by_pcie(pcie_id)
         cpu_list = get_cpu_list_by_numa(numa)
diff --git a/mx_rec/util/framework_npu_env/tfa_env.py b/mx_rec/util/framework_npu_env/tfa_env.py
index a00fd7ce0d97d3e90ddfd1e50662b5f1e869eeaf..bcd0b0ee448e6d65a441ec4d8ca9c37330f92c3a 100644
--- a/mx_rec/util/framework_npu_env/tfa_env.py
+++ b/mx_rec/util/framework_npu_env/tfa_env.py
@@ -13,14 +13,12 @@ def set_ascend_env():
     配置昇腾相关的参数和环境变量
     """
     logger.debug("Ascend env set start.")
-    os.environ["RANK_ID"] = str(get_rank_id())
     device_id = str(get_device_id())
-    os.environ["DEVICE_ID"] = device_id
     os.environ["ASCEND_DEVICE_ID"] = device_id
-    os.environ["DEVICE_INDEX"] = device_id

     if global_env.rank_table_file:
+        os.environ["RANK_ID"] = str(get_rank_id())
         rank_size = get_rank_size()
         os.environ["RANK_SIZE"] = str(rank_size)
diff --git a/mx_rec/util/global_env_conf.py
b/mx_rec/util/global_env_conf.py index 52b5af46cffd66ef0bf20d9ec2735960310eb55e..313f16936eb219de750d2ec89a63da1914860997 100644 --- a/mx_rec/util/global_env_conf.py +++ b/mx_rec/util/global_env_conf.py @@ -22,7 +22,7 @@ from mx_rec.constants.constants import EnvOption, RecPyLogLevel, Flag, EMPTY_STR DEFAULT_HD_CHANNEL_SIZE, DEFAULT_KP_THREAD_NUM, DEFAULT_FAST_UNIQUE_THREAD_NUM, RecCPPLogLevel, MAX_INT32, \ MIN_HD_CHANNEL_SIZE, MAX_HD_CHANNEL_SIZE, MIN_KP_THREAD_NUM, MAX_KP_THREAD_NUM, \ MIN_FAST_UNIQUE_THREAD_NUM, MAX_FAST_UNIQUE_THREAD_NUM, DEFAULT_HOT_EMB_UPDATE_STEP, MIN_HOT_EMB_UPDATE_STEP, \ - MAX_HOT_EMB_UPDATE_STEP, TFDevice + MAX_HOT_EMB_UPDATE_STEP, TFDevice, MAX_CM_WORKER_SIZE, MIN_CM_WORKER_SIZE, DEFAULT_CM_WORKER_SIZE from mx_rec.validator.validator import para_checker_decorator, OptionValidator, DirectoryValidator, Convert2intValidator @@ -30,7 +30,6 @@ from mx_rec.validator.validator import para_checker_decorator, OptionValidator, class RecEnv: mxrec_log_level: str rank_table_file: str - ascend_visible_devices: str cm_chief_device: str cm_worker_size: str tf_device: str @@ -45,9 +44,6 @@ class RecEnv: use_combine_faae: str stat_on: str record_key_count: str - rank_id_env: str - rank_size_env: str - local_rank_size_env: str def get_global_env_conf() -> RecEnv: @@ -58,9 +54,8 @@ def get_global_env_conf() -> RecEnv: rec_env = RecEnv( mxrec_log_level=os.getenv(EnvOption.MXREC_LOG_LEVEL.value, RecPyLogLevel.INFO.value), rank_table_file=os.getenv(EnvOption.RANK_TABLE_FILE.value, EMPTY_STR), - ascend_visible_devices=os.getenv(EnvOption.ASCEND_VISIBLE_DEVICES.value), cm_chief_device=os.getenv(EnvOption.CM_CHIEF_DEVICE.value), - cm_worker_size=os.getenv(EnvOption.CM_WORKER_SIZE.value), + cm_worker_size=os.getenv(EnvOption.CM_WORKER_SIZE.value, DEFAULT_CM_WORKER_SIZE), tf_device=os.getenv(EnvOption.TF_DEVICE.value, TFDevice.NONE.value), acl_timeout=os.getenv(EnvOption.ACL_TIMEOUT.value, "-1"), hd_channel_size=os.getenv(EnvOption.HD_CHANNEL_SIZE.value, DEFAULT_HD_CHANNEL_SIZE), @@ -72,10 +67,7 @@ def get_global_env_conf() -> RecEnv: glog_stderrthreahold=os.getenv(EnvOption.GLOG_STDERRTHREAHOLD.value, RecCPPLogLevel.INFO.value), use_combine_faae=os.getenv(EnvOption.USE_COMBINE_FAAE.value, Flag.FALSE.value), stat_on=os.getenv(EnvOption.STAT_ON.value, Flag.FALSE.value), - record_key_count=os.getenv(EnvOption.RECORD_KEY_COUNT.value, Flag.FALSE.value), - rank_id_env=os.getenv(EnvOption.OMPI_COMM_WORLD_RANK.value), - rank_size_env=os.getenv(EnvOption.OMPI_COMM_WORLD_LOCAL_SIZE.value), - local_rank_size_env=os.getenv(EnvOption.OMPI_COMM_WORLD_LOCAL_SIZE.value), + record_key_count=os.getenv(EnvOption.RECORD_KEY_COUNT.value, Flag.FALSE.value) ) return rec_env @@ -84,6 +76,8 @@ def get_global_env_conf() -> RecEnv: @para_checker_decorator(check_option_list=[ ("mxrec_log_level", OptionValidator, {"options": [i.value for i in list(RecPyLogLevel)]}), ("rank_table_file", DirectoryValidator, {}, ["check_exists_if_not_empty"]), + ("cm_worker_size", Convert2intValidator, {"min_value": MIN_CM_WORKER_SIZE, "max_value": MAX_CM_WORKER_SIZE}, + ["check_value"]), ("tf_device", OptionValidator, {"options": [i.value for i in list(TFDevice)]}), ("acl_timeout", Convert2intValidator, {"min_value": -1, "max_value": MAX_INT32}, ["check_value"]), ("hd_channel_size", Convert2intValidator, diff --git a/mx_rec/util/normalization.py b/mx_rec/util/normalization.py index dc9dd2c11c8ad2e2bc841631a993b1993db2bf0e..a9b2513238760e224356cd30a5e0cdd3cc770a65 100644 --- a/mx_rec/util/normalization.py +++ 
b/mx_rec/util/normalization.py @@ -33,6 +33,6 @@ def fix_invalid_table_name(name): if not fix_name: raise ValueError(f"The table name '{name}' doesn't contain valid character, " f"according to the rule '{pattern}'") - logger.warning(f"The table name '%s' contains invalid characters. The system automatically " - f"remove invalid characters. The table name was changed to '%s'", name, fix_name) + logger.warning("The table name '%s' contains invalid characters. The system automatically " + "remove invalid characters. The table name was changed to '%s'", name, fix_name) return fix_name diff --git a/mx_rec/util/perf.py b/mx_rec/util/perf.py index 3feb733258a0a26b05c11412c32e5b7efbcfd7eb..81089f6329bc62b72426c134402a4bdd9b9ab300 100644 --- a/mx_rec/util/perf.py +++ b/mx_rec/util/perf.py @@ -26,7 +26,7 @@ def performance(method_name): start = time.perf_counter() result = func(*args, **kwargs) span = time.perf_counter() - start - logger.debug(f"%s method consume %s (s).", method_name, round(span, 6)) + logger.debug("%s method consume %s (s).", method_name, round(span, 6)) return result return wrapper return decorator diff --git a/mx_rec/util/variable.py b/mx_rec/util/variable.py index 2c9f49a9e477f27ff707d6c86dac830c2eef463d..0040e2b5a661ca6411e863e8aa54a4c76a6e8d6c 100644 --- a/mx_rec/util/variable.py +++ b/mx_rec/util/variable.py @@ -27,11 +27,6 @@ def get_dense_and_sparse_variable(): return dense_variables, sparse_variables -def check_and_get_config_via_var(variable, optimizer_type: str): +def get_config_via_var(variable): table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance(variable) - - if not table_instance.is_hbm and not table_instance.optimizer: - raise EnvironmentError(f"When ASC with DDR, you must pass the '{optimizer_type}' optimizer instances to the" - f" init method of SparseEmbedding.") - return table_instance diff --git a/mx_rec/validator/emb_validator.py b/mx_rec/validator/emb_validator.py index c9d18f05c1860880654dc8c22d04d281af03cc84..e4417b6d7718ab6fccf0a5037616ed3a688bbaf7 100644 --- a/mx_rec/validator/emb_validator.py +++ b/mx_rec/validator/emb_validator.py @@ -57,8 +57,8 @@ def check_emb_lookup_params(table_params: dict, feature_spec: Union[tf.Tensor, F slice_device_vocabulary_size = table_params.get("slice_device_vocabulary_size") slice_host_vocabulary_size = table_params.get("slice_host_vocabulary_size") table_name = table_params.get("table_name") - if slice_host_vocabulary_size + slice_device_vocabulary_size > MAX_VOCABULARY_SIZE: - raise ValueError(f"Given device_vocabulary_size and host_vocabulary_size was too big for table " + if slice_host_vocabulary_size > MAX_VOCABULARY_SIZE: + raise ValueError(f"given host_vocabulary_size was too big for table " f"'{table_name}', in which slice_device_vocabulary_size was " f"{slice_device_vocabulary_size} and slice_host_vocabulary_size was " f"{slice_host_vocabulary_size}.") @@ -78,14 +78,14 @@ def check_emb_lookup_params(table_params: dict, feature_spec: Union[tf.Tensor, F if slice_device_vocabulary_size < send_count * rank_size: raise ValueError(f"Given device_vocabulary_size was too small for table '{table_name}', " f"in which slice_device_vocabulary_size was {slice_device_vocabulary_size} " - f"and send_count({send_count}) * rank_size({rank_size}) was " - f"{send_count * rank_size}.") + f"and it must be bigger than send_count({send_count}) * rank_size({rank_size}): " + f"{send_count * rank_size}, please increase [device vocabSize] in [create_table] interface") if slice_host_vocabulary_size < 
send_count * rank_size: raise ValueError(f"Given host_vocabulary_size was too small for table '{table_name}', " f"in which slice_host_vocabulary_size was {slice_host_vocabulary_size} " - f"and send_count({send_count}) * rank_size({rank_size}) was " - f"{send_count * rank_size}.") + f"and it must be bigger than send_count({send_count}) * rank_size({rank_size}): " + f"{send_count * rank_size}, please increase [host vocabSize] in [create_table] interface") def check_emb_multi_lookup_times(lookup_times: int, table_name: str): diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py index c9abde8757c7ee335e4efaa3f0edcff577059d14..013fe5658b164588bf85a90a78a4ab5203633c86 100644 --- a/mx_rec/validator/validator.py +++ b/mx_rec/validator/validator.py @@ -437,7 +437,14 @@ class IntValidator(NumValidator): def __init__(self, name: str, value: int, min_value: int = None, max_value: int = None, invalid_options: List = None, constrained_options: List = None, msg: str = ""): super(IntValidator, self).__init__(name, value, min_value, max_value, invalid_options, constrained_options, msg) - self.register_checker(lambda: isinstance(self.value, int), msg if msg else f"type of '{name}' is not int") + + def check_type(): + if isinstance(self.value, bool): + # bool is subclass of int + return False + return isinstance(self.value, int) + + self.register_checker(check_type, msg if msg else f"type of '{name}' is not int") class OptionalIntValidator(IntValidator): diff --git a/setup.py b/setup.py index efb4c9940f1fc7913d51962603f64b9b545545e1..87454130bc412e444763aa76edae447f5eba8d1a 100644 --- a/setup.py +++ b/setup.py @@ -16,64 +16,39 @@ # ============================================================================== import os +import glob import stat -from setuptools import setup, find_packages -import pkg_resources -from setuptools.extern.packaging import version as packaging_version - - -# Patch Version class to preserve original version string -class NoNormalizeVersion(packaging_version.Version): - def __init__(self, version): - self._orig_version = version - super().__init__(version) - - def __str__(self): - return self._orig_version - - -packaging_version.Version = NoNormalizeVersion -# Patch safe_version() to prevent version normalization -pkg_resources.safe_version = lambda v: v - -try: - with open("README.md") as file: - LONG_DESCRIPTION = file.read() -except IOError: - LONG_DESCRIPTION = "" - -env_version = os.getenv("VERSION") -VERSION = env_version if env_version is not None else '5.0.rc3' - -INIT_FILE = "mx_rec/__init__.py" -with open(INIT_FILE, 'r') as file: - lines = file.readlines() - -for idx, line in enumerate(lines): - if "__version__ = " not in line: - continue - lines[idx] = f"__version__ = '{VERSION}'\n" - break - -FLAG = os.O_WRONLY | os.O_TRUNC -MODE = stat.S_IWUSR | stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH -with os.fdopen(os.open(INIT_FILE, FLAG, MODE), 'w') as out: - out.writelines(lines) - -setup( - name='mx_rec', - version=VERSION, - author='HUAWEI Inc', - description='MindX SDK Recommend', - long_description=LONG_DESCRIPTION, - # include mx_rec - packages=find_packages( - where='.', - include=["mx_rec*"] - ), - package_dir={}, - # other file - package_data={'': ['tools/*', 'tools/*/*', '*.yml', '*.sh', '*.so*']}, - # dependency - python_requires='>=3.7.5' -) +import shutil +import subprocess + +# get the absolute path of the Python 3.7 program +res = subprocess.run(["/usr/bin/which", "python3.7"], stdout=subprocess.PIPE, text=True, shell=False) +if res.returncode: 
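For context on the `IntValidator` change earlier in this hunk: `bool` is a subclass of `int` in Python, so a plain `isinstance` check would silently accept `True`/`False`. A two-line demonstration:

```python
# bool passes a naive int check, which is exactly what check_type() above
# is guarding against.
assert isinstance(True, int)

def is_strict_int(value) -> bool:
    return isinstance(value, int) and not isinstance(value, bool)

assert is_strict_int(3)
assert not is_strict_int(True)
```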
+ raise RuntimeError("get the absolute path of the Python 3.7 program failed!") +python37_path = res.stdout.strip() + +# add execution permission to the file with the .sh suffix +scripts = glob.glob(os.path.join(os.getcwd(), "build/*.sh")) +for script in scripts: + if os.path.isfile(script): + os.chmod(script, os.stat(script).st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) + +# clean pkg_dir existed +PKG_DIR = "./build/mindxsdk-mxrec" +if os.path.exists(PKG_DIR): + shutil.rmtree(PKG_DIR) + +# build tf1's wheel file +res = subprocess.run([python37_path, "setup_tf1.py", "bdist_wheel"], shell=False) +if res.returncode: + raise RuntimeError(f"build tf1's wheel file failed!") + +# build tf2's wheel file +res = subprocess.run([python37_path, "setup_tf2.py", "bdist_wheel"], shell=False) +if res.returncode: + raise RuntimeError(f"build tf2's wheel file failed!") + +# copy cust_op, examples files, etc. Then gen mxrec's tar pkg +res = subprocess.run(["./build/gen_mxrec_tar_pkg.sh"], shell=False) +if res.returncode: + raise RuntimeError(f"gen mxrec's tar pkg failed!") diff --git a/setup_tf1.py b/setup_tf1.py new file mode 100644 index 0000000000000000000000000000000000000000..df8c731edf47fe5ab5b1ad2d01085404fbbf075b --- /dev/null +++ b/setup_tf1.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import os +import stat +import subprocess +from setuptools import setup, find_packages +import pkg_resources +from setuptools.extern.packaging import version as packaging_version + +script_path = os.getcwd() + + +# Patch Version class to preserve original version string +class NoNormalizeVersion(packaging_version.Version): + def __init__(self, version): + self._orig_version = version + super().__init__(version) + + def __str__(self): + return self._orig_version + + +def safe_version(v): + return v + + +packaging_version.Version = NoNormalizeVersion +# Patch safe_version() to prevent version normalization +pkg_resources.safe_version = safe_version + +try: + with open("README.md") as file: + LONG_DESCRIPTION = file.read() +except IOError: + LONG_DESCRIPTION = "" + +env_version = os.getenv("VERSION") +VERSION = env_version if env_version is not None else '6.0.RC2' + +INIT_FILE = "mx_rec/__init__.py" +with open(INIT_FILE, 'r') as file: + lines = file.readlines() + +for idx, line in enumerate(lines): + if "__version__ = " not in line: + continue + lines[idx] = f"__version__ = '{VERSION}'\n" + break + +FLAG = os.O_WRONLY | os.O_TRUNC +MODE = stat.S_IWUSR | stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH +with os.fdopen(os.open(INIT_FILE, FLAG, MODE), 'w') as out: + out.writelines(lines) + +# compile so files +tf1_script = os.path.join(script_path, "./build/build_tf1.sh") +res = subprocess.run([tf1_script], shell=False) +if res.returncode: + raise RuntimeError("compile so files failed!") + +setup( + name='mx_rec', + version=VERSION, + author='HUAWEI Inc', + description='MindX SDK Recommend', + long_description=LONG_DESCRIPTION, + # include mx_rec + packages=find_packages( + where='.', + include=["mx_rec*"] + ), + # other file + package_data={'': ['tools/*', 'tools/*/*', '*.yml', '*.sh', '*.so*']}, + # dependency + python_requires='>=3.7.5' +) + +move_whl_script = os.path.join(script_path, "./build/move_whl_file_2_pkg_dir.sh") +res = subprocess.run([move_whl_script, "tf1"], shell=False) +if res.returncode: + raise RuntimeError(f"move tf1 whl file to pkg dir failed!") diff --git a/setup_tf2.py b/setup_tf2.py new file mode 100644 index 0000000000000000000000000000000000000000..31e61a998d25a2b8b59b91b98bcb2fa5386601fd --- /dev/null +++ b/setup_tf2.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
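The `NoNormalizeVersion`/`safe_version` patching in these setup scripts exists because PEP 440 tooling normalizes version strings, which would change the wheel's version tag. A quick check with the standalone `packaging` library (the scripts patch the copy vendored by setuptools, but the normalization behaviour is the same):

```python
from packaging.version import Version

# PEP 440 normalization rewrites the raw release string; the setup patch
# above preserves "6.0.RC2" verbatim instead.
assert str(Version("6.0.RC2")) == "6.0rc2"
```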
+# ============================================================================== + +import os +import stat +import subprocess +from setuptools import setup, find_packages +import pkg_resources +from setuptools.extern.packaging import version as packaging_version + +script_path = os.getcwd() + + +# Patch Version class to preserve original version string +class NoNormalizeVersion(packaging_version.Version): + def __init__(self, version): + self._orig_version = version + super().__init__(version) + + def __str__(self): + return self._orig_version + + +def safe_version(v): + return v + + +packaging_version.Version = NoNormalizeVersion +# Patch safe_version() to prevent version normalization +pkg_resources.safe_version = safe_version + +try: + with open("README.md") as file: + LONG_DESCRIPTION = file.read() +except IOError: + LONG_DESCRIPTION = "" + +env_version = os.getenv("VERSION") +VERSION = env_version if env_version is not None else '6.0.RC2' + +INIT_FILE = "mx_rec/__init__.py" +with open(INIT_FILE, 'r') as file: + lines = file.readlines() + +for idx, line in enumerate(lines): + if "__version__ = " not in line: + continue + lines[idx] = f"__version__ = '{VERSION}'\n" + break + +FLAG = os.O_WRONLY | os.O_TRUNC +MODE = stat.S_IWUSR | stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH +with os.fdopen(os.open(INIT_FILE, FLAG, MODE), 'w') as out: + out.writelines(lines) + +# compile so files +tf2_script = os.path.join(script_path, "./build/build_tf2.sh") +res = subprocess.run([tf2_script], shell=False) +if res.returncode: + raise RuntimeError("compile so files failed!") + +setup( + name='mx_rec', + version=VERSION, + author='HUAWEI Inc', + description='MindX SDK Recommend', + long_description=LONG_DESCRIPTION, + # include mx_rec + packages=find_packages( + where='.', + include=["mx_rec*"] + ), + # other file + package_data={'': ['tools/*', 'tools/*/*', '*.yml', '*.sh', '*.so*']}, + # dependency + python_requires='>=3.7.5' +) + +move_whl_script = os.path.join(script_path, "./build/move_whl_file_2_pkg_dir.sh") +res = subprocess.run([move_whl_script, "tf2"], shell=False) +if res.returncode: + raise RuntimeError(f"move tf2 whl file to pkg dir failed!") diff --git a/src/AccCTR/3rdparty/CMakeLists.txt b/src/AccCTR/3rdparty/CMakeLists.txt index a17e472c0ba0a1706c8f73d7042393932f89a174..3a05f5859d389c40245b8e2e22815b44b30b7b02 100644 --- a/src/AccCTR/3rdparty/CMakeLists.txt +++ b/src/AccCTR/3rdparty/CMakeLists.txt @@ -1,3 +1,17 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + message("build mode " ${BUILD_MODE}) set(PLATFORM_UTILITIES_3RDPARTY_SOURCE_DIR ${PROJECT_SOURCE_DIR}/../../../opensource) diff --git a/src/AccCTR/CMakeLists.txt b/src/AccCTR/CMakeLists.txt index 0cb6317607e7fa2547c62a3ad6b3f3542953e0e3..febf1740793b35f6c61de5baa9eda0ca45f73e6c 100644 --- a/src/AccCTR/CMakeLists.txt +++ b/src/AccCTR/CMakeLists.txt @@ -23,8 +23,6 @@ if (${BUILD_MODE} MATCHES "release") -Wall -fPIC -fms-extensions - -Wno-unused-parameter - -Wno-unused-function -Wunused-variable -Wunused-value -Wcast-align @@ -47,8 +45,6 @@ elseif (${BUILD_MODE} MATCHES "debug") -Wall -fPIC -fms-extensions - -Wno-unused-parameter - -Wno-unused-function -Wunused-variable -Wunused-value -Winvalid-pch @@ -67,8 +63,6 @@ elseif (${BUILD_MODE} MATCHES "ut") -Wall -fPIC -fms-extensions - -Wno-unused-parameter - -Wno-unused-function -Wunused-variable -Wunused-value -Winvalid-pch @@ -79,10 +73,10 @@ elseif (${BUILD_MODE} MATCHES "ut") -Wfloat-equal -Wextra -std=c++17 - #-fsanitize=address - #-fno-omit-frame-pointer - #-fstack-protector-all - #-fstack-protector-strong + -fsanitize=address + -fsanitize-recover=address,all + -fno-omit-frame-pointer + -fstack-protector-all ) else () message(FATAL_ERROR "======BUILD_MODE not found") @@ -100,7 +94,6 @@ elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64") ${CXX_FLAGS} -msse2 -mavx - #-w ) else () message(FATAL_ERROR "don't support ${CMAKE_HOST_SYSTEM_PROCESSOR}") @@ -110,6 +103,11 @@ set(OCK_CTR_PLATFORM_UTIL_DIR ${PROJECT_SOURCE_DIR}/../../../opensource) message(===============${OCK_CTR_PLATFORM_UTIL_DIR}) include_directories(${OCK_CTR_PLATFORM_UTIL_DIR}/securec/include) +include_directories( + ${PROJECT_SOURCE_DIR}/src + ${PROJECT_SOURCE_DIR}/src/embedding_cache +) + add_subdirectory(3rdparty) add_subdirectory(src) diff --git a/src/AccCTR/README.md b/src/AccCTR/README.md index 1a394699500f35176fb91b6549fb78c9859674c6..1b25534dff109f24cb9333a23155c8d13122ecd6 100644 --- a/src/AccCTR/README.md +++ b/src/AccCTR/README.md @@ -6,4 +6,6 @@ 2、bash build.sh debug //编译debug -3、bash build.sh ut //编译并运行ut,覆盖率在tests/build/cov/gen目录下 +3、编译和运行UT: + (1)bash build.sh ut //编译ut,覆盖率在tests/build/cov/gen目录下 + (2)cd build && bash build_test.sh ut //进入到build目录下并运行ut \ No newline at end of file diff --git a/src/AccCTR/build/build_test.sh b/src/AccCTR/build/build_test.sh index 9441efe39f8e65d1c3e0dbda243409cca8ee73ba..4001b825273a57a50e24258be6c4fa8132ec28bd 100644 --- a/src/AccCTR/build/build_test.sh +++ b/src/AccCTR/build/build_test.sh @@ -24,6 +24,9 @@ TOOL_FILE="create_fake_id.py" CPU_TYPE=$(arch) BUILD_MODE=$1 +# config asan environment variable +export ASAN_OPTIONS=halt_on_error=1:detect_leaks=1 + create_data() { cd ${TOOL_PATH} diff --git a/src/AccCTR/src/CMakeLists.txt b/src/AccCTR/src/CMakeLists.txt index 09da4670f274e4f4b9f3806e7a238d6e971702a4..1f4d92695f60bf9b80277c9bcead6cd983dfb963 100644 --- a/src/AccCTR/src/CMakeLists.txt +++ b/src/AccCTR/src/CMakeLists.txt @@ -23,12 +23,17 @@ set(OUTPUT ${PROJECT_SOURCE_DIR}/output) set(OCK_CTR_PLATFORM_UTIL_DIR ${PROJECT_SOURCE_DIR}/../../../opensource) set(OCK_CTR_UTIL_INSTALL_DIR ${PROJECT_SOURCE_DIR}/install) - if (${BUILD_MODE} MATCHES "ut") add_compile_options(-ftest-coverage -fprofile-arcs) link_libraries(gcov) +else() + add_compile_options(-D_GLIBCXX_USE_CXX11_ABI=0) # must set this option otherwise pybind will not find embCache symbol endif (${BUILD_MODE} MATCHES "ut") +if (${BUILD_MODE} MATCHES "fuzz") + 
add_compile_options(-ftest-coverage -fprofile-arcs -fdump-rtl-expand)
+    link_libraries(gcov asan)
+endif (${BUILD_MODE} MATCHES "fuzz")

 message("include : " ${OCK_CTR_SRC_INCLUDE_DIR})
@@ -37,6 +42,7 @@ set(LIB_HW_SECURE ${OCK_CTR_PLATFORM_UTIL_DIR}/securec/lib/libsecurec.so)
 add_subdirectory(include)
 add_subdirectory(common)
 add_subdirectory(unique)
+add_subdirectory(embedding_cache)

 file(GLOB_RECURSE CTR_SRC factory_impl.cpp)
@@ -52,6 +58,7 @@ target_include_directories(_ock_ctr_common
 target_link_libraries(_ock_ctr_common
         PUBLIC
         -Wl,--start-group
         unique
+        embedding_cache
         dl
         utils
         ${LIB_HW_SECURE}
diff --git a/src/AccCTR/src/common/util/error_code.h b/src/AccCTR/src/common/util/error_code.h
index 04d26a5730bd1839ae82307f06dab878ba4264bf..87c8ffe61ca2dbc6f2f87692ac299d501ef2d9e0 100644
--- a/src/AccCTR/src/common/util/error_code.h
+++ b/src/AccCTR/src/common/util/error_code.h
@@ -29,7 +29,21 @@ using CTRCode = enum : int {
     H_OUTPUT_TYPE_ERROR = 8,
     H_SCENE_ERROR = 9,
     H_MEMORY_ALLOC_ERROR = 10,
-    H_UNIQUE_UNINITIALIZED_ERROR = 11
+    H_UNIQUE_UNINITIALIZED_ERROR = 11,
+    H_TABLE_NOT_EXIST = 12,
+    H_LOAD_ERROR = 13,
+    H_INITIALIZER_INVALID = 14,
+    H_EXT_EMBEDDING_SIZE_INVALID = 15,
+    H_MAX_CACHESIZE_TOO_SMALL = 16,
+    H_HOST_VOCAB_SIZE_TOO_SMALL = 17,
+    H_THREAD_NUM_ERROR = 18,
+    H_TABLE_CREATE_DUPLICATE = 19,
+    H_ARG_NOT_EMPTY = 20,
+    H_SIZE_ZERO = 21,
+    H_TABLE_NAME_EMPTY = 22,
+    H_PREFILL_BUFFER_SIZE_INVALID = 23,
+    H_TABLE_NAME_TOO_LONG = 24,
+    H_EMB_CACHE_INFO_LOST = 25
 };
 }
 }
diff --git a/src/AccCTR/src/common/util/external_threader.h b/src/AccCTR/src/common/util/external_threader.h
index 5a1132af94dbb729680153f7163dba6213bae4ec..5f7c500fa1cbaeda0aae482815406a1e927651f3 100644
--- a/src/AccCTR/src/common/util/external_threader.h
+++ b/src/AccCTR/src/common/util/external_threader.h
@@ -20,11 +20,81 @@ limitations under the License.
 #include
 #include
 #include
+#include
+#include
+#include
+#include
 #include "singleton.h"

 using ExternalThread = void (*)(const std::vector<std::function<void()>> &tasks);

 namespace ock {
+class ThreadPoolAsync {
+public:
+    ThreadPoolAsync() : stop(false) {}
+
+    ~ThreadPoolAsync()
+    {
+        {
+            std::lock_guard<std::mutex> lock(taskMutex);
+            stop = true;
+        }
+        taskCv.notify_all();
+        for (auto &t : workerThreads) {
+            t.join();
+        }
+    }
+
+    void SetNumThreads(int n)
+    {
+        if (n < 1) {
+            return;
+        }
+
+        for (int i = 0; i < n; ++i) {
+            workerThreads.emplace_back(std::bind(&ThreadPoolAsync::WorkerThread, this));
+        }
+    }
+
+    template <typename F> std::future<void> AddTask(F &&f)
+    {
+        std::lock_guard<std::mutex> lock(taskMutex);
+
+        auto pt = std::make_unique<std::packaged_task<void()>>(std::forward<F>(f));
+        auto fut = pt->get_future();
+        tasks.emplace(std::move(pt));
+        taskCv.notify_one();
+        return fut;
+    }
+
+private:
+    std::vector<std::thread> workerThreads;
+    std::queue<std::unique_ptr<std::packaged_task<void()>>> tasks;
+    std::mutex taskMutex;
+    std::condition_variable taskCv;
+    std::atomic<bool> stop = false;
+
+    void WorkerThread()
+    {
+        while (true) {
+            std::unique_ptr<std::packaged_task<void()>> task;
+            {
+                std::unique_lock<std::mutex> lock(taskMutex);
+                while (tasks.empty() && !stop) {
+                    taskCv.wait(lock);
+                }
+                if (stop) {
+                    break;
+                }
+                task = std::move(tasks.front());
+                tasks.pop();
+            }
+            (*task)();
+        }
+    }
+};
+
+
 class SimpleThreadPool {
 public:
     static void SyncRun(const std::vector<std::function<void()>> &tasks)
diff --git a/src/AccCTR/src/embedding_cache/CMakeLists.txt b/src/AccCTR/src/embedding_cache/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e0278a6ebadf51d705b1f3f8c86809ed9c458d64
--- /dev/null
+++ b/src/AccCTR/src/embedding_cache/CMakeLists.txt
@@ -0,0 +1,27 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024.
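For readers who want the contract of the new `ThreadPoolAsync` in a compact form: it is the usual submit-and-future pattern, shown here with Python's standard library purely as a behavioural comparison, not as part of the patch:

```python
from concurrent.futures import ThreadPoolExecutor

# SetNumThreads(4) ~ max_workers=4; AddTask(f) ~ submit(f) returning a
# future; the destructor's stop/notify/join ~ the context-manager exit.
with ThreadPoolExecutor(max_workers=4) as pool:
    fut = pool.submit(lambda: 1 + 1)
    assert fut.result() == 2
```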
All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +file(GLOB_RECURSE SRCS *.cpp *.h) + +add_library(embedding_cache OBJECT ${SRCS}) + +target_link_libraries(embedding_cache + -Wl,--start-group + -Wl,--end-group + ) + +target_include_directories(embedding_cache + PUBLIC + ${PROJECT_SOURCE_DIR}/src/common/util + ${PROJECT_SOURCE_DIR}/src/include) \ No newline at end of file diff --git a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5257882036a2ccbd3e52096bfee4b0aa3b1720b3 --- /dev/null +++ b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp @@ -0,0 +1,492 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ ==============================================================================*/ + +#include "cache_manager.h" + +#include + +#include "external_logger.h" + +using namespace EmbCache; +using namespace ock; +using namespace ock::ctr; + +int64_t EmbCache::INVALID_KEY = -1; + +int EmbCacheManagerImpl::CreateCacheForTable(const EmbCacheInfo& embCacheInfo, + const std::vector& initializerInfos, int64_t invalidKey, + uint64_t prefillBufferSize, uint32_t refillThreadNum) +{ + int checkTableNameRet = CheckCreateTableName(embCacheInfo.tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + + if (embCacheInfo.extEmbeddingSize == 0 || embCacheInfo.embeddingSize == 0 || embCacheInfo.vocabSize == 0 || + embCacheInfo.maxCacheSize == 0) { + ExternalLogger::PrintLog(LogLevel::ERROR, "size must be positive"); + return H_SIZE_ZERO; + } + + if (embCacheInfo.vocabSize < embCacheInfo.maxCacheSize) { + ExternalLogger::PrintLog(LogLevel::ERROR, "host vocabSize:" + std::to_string(embCacheInfo.vocabSize) + + " must be greater than or equal to device vocabSize:" + std::to_string(embCacheInfo.maxCacheSize) + + ", please increase [host vocabSize] in [create_table] interface"); + return H_HOST_VOCAB_SIZE_TOO_SMALL; + } + + auto om = offsetMappers.find(embCacheInfo.tableName); + auto embTable = embTables.find(embCacheInfo.tableName); + if (om != offsetMappers.end() || embTable != embTables.end()) { + ExternalLogger::PrintLog(LogLevel::ERROR, "This table has already been created"); + return H_TABLE_CREATE_DUPLICATE; + } + + if (embCacheInfo.extEmbeddingSize % embCacheInfo.embeddingSize != 0) { + ExternalLogger::PrintLog(LogLevel::ERROR, "extEmbeddingSize = embeddingSize + optimizerSize, " + "which is divisible by embeddingSize"); + return H_EXT_EMBEDDING_SIZE_INVALID; + } + + if (!CheckInitializer(embCacheInfo.extEmbeddingSize, initializerInfos)) { + return H_INITIALIZER_INVALID; + } + + if ((prefillBufferSize < 1) || (prefillBufferSize > embCacheInfo.vocabSize)) { + ExternalLogger::PrintLog(LogLevel::ERROR, "PrefillBufferSize: " + std::to_string(prefillBufferSize) + + " has to be between [1, hostVocabSize]."); + return H_PREFILL_BUFFER_SIZE_INVALID; + } + + if (!CheckValidThreadNum(refillThreadNum)) { + return H_THREAD_NUM_ERROR; + } + + uint32_t reserveDevice = embCacheInfo.maxCacheSize / VOCAB_CACHE_RATIO; + if (!offsetMappers[embCacheInfo.tableName].Initialize(reserveDevice, embCacheInfo.maxCacheSize)) { + offsetMappers[embCacheInfo.tableName].UnInitialize(); + offsetMappers.erase(embCacheInfo.tableName); + return H_MEMORY_ALLOC_ERROR; + } + + EmbPoolParam embPoolParam{prefillBufferSize, refillThreadNum}; + uint32_t reserveHost = embCacheInfo.vocabSize / VOCAB_CACHE_RATIO; + if (!embTables[embCacheInfo.tableName].Initialize(embCacheInfo, reserveHost, initializerInfos, embPoolParam)) { + offsetMappers.erase(embCacheInfo.tableName); + embTables.erase(embCacheInfo.tableName); + return H_MEMORY_ALLOC_ERROR; + } + + embCacheInfos.insert({embCacheInfo.tableName, embCacheInfo}); + INVALID_KEY = invalidKey; + return H_OK; +} + +int EmbCacheManagerImpl::GetSwapPairsAndKey2Offset(const std::string& tableName, std::vector& keys, + KeyOffsetPair& swapInKoPair, KeyOffsetPair& swapOutKoPair) +{ + int checkRet = CheckGetSwapPairsAndKey2Offset(tableName, swapInKoPair, swapOutKoPair); + if (checkRet != H_OK) { + return checkRet; + } + return offsetMappers[tableName].GetSwapPairsAndKey2Offset(keys, swapInKoPair, swapOutKoPair); +} + +int EmbCacheManagerImpl::EmbeddingLookup(const std::string& tableName, 
const std::vector& keys, + float* embAddr, uint32_t threadNum) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + + if (!CheckValidThreadNum(threadNum)) { + return H_THREAD_NUM_ERROR; + } + + if (keys.empty()) { + return H_OK; + } + + if (embAddr == nullptr) { + ExternalLogger::PrintLog(LogLevel::ERROR, "embAddr is nullptr"); + return H_ADDRESS_NULL; + } + + return embTables[tableName].Gather(reinterpret_cast(embAddr), keys, threadNum); +} + +int EmbCacheManagerImpl::EmbeddingLookupAddrs(const std::string& tableName, const std::vector& keys, + std::vector& addrs, uint32_t threadNum) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + + if (!CheckValidThreadNum(threadNum)) { + return H_THREAD_NUM_ERROR; + } + + if (keys.empty()) { + return H_OK; + } + + return embTables[tableName].GatherAddrs(keys, addrs, threadNum); +} + +// 如果多线程使用,严格保证传入的key线程间不会重复(unique key),否则可能出现未定义结果 +int EmbCacheManagerImpl::EmbeddingLookupAndRemove(const std::string& tableName, const std::vector& keys, + float* embAddr, uint32_t threadNum) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + + if (!CheckValidThreadNum(threadNum)) { + return H_THREAD_NUM_ERROR; + } + + if (keys.empty()) { + return H_OK; + } + + if (embAddr == nullptr) { + ExternalLogger::PrintLog(LogLevel::ERROR, "embAddr is nullptr"); + return H_ADDRESS_NULL; + } + + return embTables[tableName].GatherAndRemove(reinterpret_cast(embAddr), keys, threadNum); +} + +int EmbCacheManagerImpl::EmbeddingUpdate(const std::string& tableName, const std::vector& keys, + float* embAddr, uint32_t threadNum) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + + if (!CheckValidThreadNum(threadNum)) { // 检查thread是否小于核数 + return H_THREAD_NUM_ERROR; + } + + if (keys.empty()) { + return H_OK; + } + + if (embAddr == nullptr) { // 检查embAddr是不是空指针 + ExternalLogger::PrintLog(LogLevel::ERROR, "embAddr is nullptr"); + return H_ADDRESS_NULL; + } + + return embTables[tableName].Scatter(reinterpret_cast(embAddr), keys, threadNum); +} + +int EmbCacheManagerImpl::EmbeddingRemove(const std::string& tableName, const std::vector& keys, + uint32_t threadNum) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + + if (!CheckValidThreadNum(threadNum)) { // 检查thread是否小于核数 + return H_THREAD_NUM_ERROR; + } + + if (keys.empty()) { + return H_OK; + } + + return embTables[tableName].RemoveByKeys(keys, threadNum); +} + +int EmbCacheManagerImpl::RemoveEmbsByKeys(const std::string& tableName, const std::vector& keys) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + const auto& om = offsetMappers.find(tableName); + const auto& embTable = embTables.find(tableName); + for (auto key : keys) { + if (key == static_cast(INVALID_KEY)) { + ExternalLogger::PrintLog(LogLevel::WARN, "Try to evict invalid key"); + continue; + } + om->second.Remove(key); + embTable->second.Remove(key); + } + return H_OK; +} + +int EmbCacheManagerImpl::GetEmbTableNames(std::vector& allTableNames) +{ + if (!allTableNames.empty()) { + ExternalLogger::PrintLog(LogLevel::ERROR, "allTableNames should be empty"); + return H_ARG_NOT_EMPTY; + } + 
allTableNames.reserve(embTables.size()); + for (auto& embTable : embTables) { + allTableNames.emplace_back(embTable.first); + } + return H_OK; +} + +int EmbCacheManagerImpl::ExportDeviceKeyOffsetPairs(const std::string& tableName, + std::vector>& koVec) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + koVec = offsetMappers[tableName].ExportSortedKVPairs(); + return H_OK; +} + +int EmbCacheManagerImpl::Serialize(const std::string& tableName, std::vector& buffer) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + buffer = embTables[tableName].Serialize(); + return H_OK; +} + +int EmbCacheManagerImpl::Deserialize(const std::string& tableName, const std::vector& buffer) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + if (!embTables[tableName].Deserialize(buffer)) { + return H_LOAD_ERROR; + } + return H_OK; +} + +int EmbCacheManagerImpl::GetEmbTableInfos(std::string tableName, std::vector& keys, + std::vector>& embeddings, + std::vector>& optimizerSlots) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + if (!keys.empty()) { + ExternalLogger::PrintLog(LogLevel::ERROR, "keys should be empty"); + return H_ARG_NOT_EMPTY; + } + if (!embeddings.empty()) { + ExternalLogger::PrintLog(LogLevel::ERROR, "embeddings should be empty"); + return H_ARG_NOT_EMPTY; + } + if (!optimizerSlots.empty()) { + ExternalLogger::PrintLog(LogLevel::ERROR, "optimizerSlots should be empty"); + return H_ARG_NOT_EMPTY; + } + embTables[tableName].GetEmbTableInfos(keys, embeddings, optimizerSlots); + return H_OK; +} + +int EmbCacheManagerImpl::LoadEmbTableInfos(std::string tableName, const std::vector& keys, + const std::vector>& embeddings, + const std::vector>& optimizerSlots) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + if (!embTables[tableName].LoadEmbTableInfos(keys, embeddings, optimizerSlots)) { + return H_LOAD_ERROR; + } + return H_OK; +} + +int EmbCacheManagerImpl::BackUpTrainStatus(const std::string& tableName) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + + // Back up the key-offset correspondence on the device + kvVecsBackUp[tableName] = offsetMappers[tableName].ExportVec(); + + auto embInfo = embCacheInfos.find(tableName); + if (embInfo == embCacheInfos.end()) { + return H_EMB_CACHE_INFO_LOST; + } + uint32_t reserve = embInfo->second.maxCacheSize / VOCAB_CACHE_RATIO; + uint32_t maxCacheSize = embInfo->second.maxCacheSize; + + auto om = offsetMappersBackUp.find(tableName); + if (om != offsetMappersBackUp.end()) { + offsetMappersBackUp[tableName].UnInitialize(); + } + offsetMappersBackUp[tableName].Initialize(reserve, maxCacheSize); + offsetMappersBackUp[tableName] = offsetMappers[tableName]; + + return H_OK; +} + +int EmbCacheManagerImpl::RecoverTrainStatus(const std::string& tableName) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + + auto embInfo = embCacheInfos.find(tableName); + if (embInfo == embCacheInfos.end()) { + return H_EMB_CACHE_INFO_LOST; + } + uint32_t reserve = embInfo->second.maxCacheSize / VOCAB_CACHE_RATIO; + uint32_t maxCacheSize = 
+int EmbCacheManagerImpl::GetEmbTableInfos(std::string tableName, std::vector<uint64_t>& keys,
+                                          std::vector<std::vector<float>>& embeddings,
+                                          std::vector<std::vector<float>>& optimizerSlots)
+{
+    int checkTableNameRet = CheckValidTableName(tableName);
+    if (checkTableNameRet != H_OK) {
+        return checkTableNameRet;
+    }
+    if (!keys.empty()) {
+        ExternalLogger::PrintLog(LogLevel::ERROR, "keys should be empty");
+        return H_ARG_NOT_EMPTY;
+    }
+    if (!embeddings.empty()) {
+        ExternalLogger::PrintLog(LogLevel::ERROR, "embeddings should be empty");
+        return H_ARG_NOT_EMPTY;
+    }
+    if (!optimizerSlots.empty()) {
+        ExternalLogger::PrintLog(LogLevel::ERROR, "optimizerSlots should be empty");
+        return H_ARG_NOT_EMPTY;
+    }
+    embTables[tableName].GetEmbTableInfos(keys, embeddings, optimizerSlots);
+    return H_OK;
+}
+
+int EmbCacheManagerImpl::LoadEmbTableInfos(std::string tableName, const std::vector<uint64_t>& keys,
+                                           const std::vector<std::vector<float>>& embeddings,
+                                           const std::vector<std::vector<float>>& optimizerSlots)
+{
+    int checkTableNameRet = CheckValidTableName(tableName);
+    if (checkTableNameRet != H_OK) {
+        return checkTableNameRet;
+    }
+    if (!embTables[tableName].LoadEmbTableInfos(keys, embeddings, optimizerSlots)) {
+        return H_LOAD_ERROR;
+    }
+    return H_OK;
+}
+
+int EmbCacheManagerImpl::BackUpTrainStatus(const std::string& tableName)
+{
+    int checkTableNameRet = CheckValidTableName(tableName);
+    if (checkTableNameRet != H_OK) {
+        return checkTableNameRet;
+    }
+
+    // Back up the key-offset correspondence on the device
+    kvVecsBackUp[tableName] = offsetMappers[tableName].ExportVec();
+
+    auto embInfo = embCacheInfos.find(tableName);
+    if (embInfo == embCacheInfos.end()) {
+        return H_EMB_CACHE_INFO_LOST;
+    }
+    uint32_t reserve = embInfo->second.maxCacheSize / VOCAB_CACHE_RATIO;
+    uint32_t maxCacheSize = embInfo->second.maxCacheSize;
+
+    auto om = offsetMappersBackUp.find(tableName);
+    if (om != offsetMappersBackUp.end()) {
+        offsetMappersBackUp[tableName].UnInitialize();
+    }
+    offsetMappersBackUp[tableName].Initialize(reserve, maxCacheSize);
+    offsetMappersBackUp[tableName] = offsetMappers[tableName];
+
+    return H_OK;
+}
+
+int EmbCacheManagerImpl::RecoverTrainStatus(const std::string& tableName)
+{
+    int checkTableNameRet = CheckValidTableName(tableName);
+    if (checkTableNameRet != H_OK) {
+        return checkTableNameRet;
+    }
+
+    auto embInfo = embCacheInfos.find(tableName);
+    if (embInfo == embCacheInfos.end()) {
+        return H_EMB_CACHE_INFO_LOST;
+    }
+    uint32_t reserve = embInfo->second.maxCacheSize / VOCAB_CACHE_RATIO;
+    uint32_t maxCacheSize = embInfo->second.maxCacheSize;
+
+    offsetMappers[tableName].UnInitialize();
+    offsetMappers[tableName].Initialize(reserve, maxCacheSize);
+    offsetMappers[tableName] = offsetMappersBackUp[tableName];
+
+    // Recover the key-offset correspondence on the device; take a reference so the
+    // clear() below releases the stored backup rather than a throwaway local copy.
+    auto& kvVecBackUp = kvVecsBackUp[tableName];
+    for (const auto& kvPair : kvVecBackUp) {
+        offsetMappers[tableName].Put(kvPair.first, kvPair.second);
+    }
+
+    kvVecBackUp.clear();
+    return H_OK;
+}
+
+void EmbCacheManagerImpl::Destroy()
+{
+    for (auto it = offsetMappers.begin(); it != offsetMappers.end(); it++) {
+        it->second.UnInitialize();
+    }
+    for (auto it = embTables.begin(); it != embTables.end(); it++) {
+        it->second.UnInitialize();
+    }
+    embCacheInfos.clear();
+    offsetMappers.clear();
+    embTables.clear();
+}
+
+int EmbCacheManagerImpl::CheckValidTableName(const std::string& tableName)
+{
+    if (tableName.size() > TABLE_NAME_MAX_SIZE) {
+        ExternalLogger::PrintLog(LogLevel::ERROR,
+                                 "tableName size can not be larger than " + std::to_string(TABLE_NAME_MAX_SIZE));
+        return H_TABLE_NAME_TOO_LONG;
+    }
+    auto om = offsetMappers.find(tableName);
+    auto embTable = embTables.find(tableName);
+    if (om == offsetMappers.end() || embTable == embTables.end()) {
+        ExternalLogger::PrintLog(LogLevel::ERROR, "can not find table");
+        return H_TABLE_NOT_EXIST;
+    }
+    return H_OK;
+}
+
+bool EmbCacheManagerImpl::CheckInitializer(uint32_t extEmbSize, std::vector<InitializerInfo> initializerInfos)
+{
+    std::sort(initializerInfos.begin(), initializerInfos.end(),
+              [](const auto& u, const auto& v) { return u.start < v.start; });
+    uint32_t cur_pos = 0;
+    for (const auto& info : initializerInfos) {
+        if (info.initializer == nullptr) {
+            ExternalLogger::PrintLog(LogLevel::ERROR, "initializer is nullptr");
+            return false;
+        }
+        if (info.start != cur_pos) {
+            ExternalLogger::PrintLog(LogLevel::ERROR, "Initializers got coverage problems");
+            return false;
+        }
+        cur_pos += info.len;
+    }
+    // finally, check that the segments cover the extended embedding exactly
+    if (cur_pos != extEmbSize) {
+        ExternalLogger::PrintLog(LogLevel::ERROR, "Initializers got coverage problems");
+        return false;
+    }
+    return true;
+}
+
+bool EmbCacheManagerImpl::CheckValidThreadNum(uint32_t threadNum)
+{
+    uint32_t processCoreNum = std::thread::hardware_concurrency();
+    if (threadNum > processCoreNum) {
+        ExternalLogger::PrintLog(LogLevel::ERROR, "ThreadNum can not be larger than the CPU core num");
+        return false;
+    }
+
+    if (threadNum == 0) {
+        ExternalLogger::PrintLog(LogLevel::ERROR, "ThreadNum can not be zero");
+        return false;
+    }
+    return true;
+}
+
+int EmbCacheManagerImpl::CheckGetSwapPairsAndKey2Offset(const std::string& tableName, const KeyOffsetPair& swapInKoPair,
+                                                        const KeyOffsetPair& swapOutKoPair)
+{
+    if (!swapInKoPair.first.empty() || !swapInKoPair.second.empty() || !swapOutKoPair.first.empty() ||
+        !swapOutKoPair.second.empty()) {
+        ExternalLogger::PrintLog(LogLevel::ERROR, "koPair should be empty");
+        return H_ARG_NOT_EMPTY;
+    }
+
+    int checkTableNameRet = CheckValidTableName(tableName);
+    if (checkTableNameRet != H_OK) {
+        return checkTableNameRet;
+    }
+
+    return H_OK;
+}
+
+int EmbCacheManagerImpl::CheckCreateTableName(const std::string& tableName)
+{
+    if (tableName.empty()) {
+        ExternalLogger::PrintLog(LogLevel::ERROR, "tableName can not be empty");
+        return H_TABLE_NAME_EMPTY;
+    }
+
+    if (tableName.size() > TABLE_NAME_MAX_SIZE) {
+        ExternalLogger::PrintLog(LogLevel::ERROR,
+                                 "tableName size can not be larger than " + std::to_string(TABLE_NAME_MAX_SIZE));
+        return H_TABLE_NAME_TOO_LONG;
+    }
+    return H_OK;
+}
+
+uint32_t EmbCacheManagerImpl::GetUsage(const 
std::string& tableName) +{ + return embTables[tableName].GetUsage(); +} + +int EmbCacheManagerImpl::ResetOffsetMappers() +{ + for (auto it = offsetMappers.begin(); it != offsetMappers.end(); it++) { + auto embInfo = embCacheInfos.find(it->first); + if (embInfo == embCacheInfos.end()) { + return H_EMB_CACHE_INFO_LOST; + } + it->second.UnInitialize(); + uint32_t reserve = embInfo->second.maxCacheSize / VOCAB_CACHE_RATIO; + it->second.Initialize(reserve, embInfo->second.maxCacheSize); + } + return H_OK; +} diff --git a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.h b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..e4a240ae4945556dac3ec724c39d0485b569b84e --- /dev/null +++ b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.h @@ -0,0 +1,103 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef EMBEDDING_CACHE_MANAGER_H +#define EMBEDDING_CACHE_MANAGER_H + +#include +#include +#include +#include + +#include "embedding_cache.h" +#include "embedding_local_table/emb_local_table.h" +#include "error_code.h" +#include "offset_mapper/offset_mapper.h" + +namespace EmbCache { +class EmbCacheManagerImpl : public EmbCacheManager { +public: + EmbCacheManagerImpl() = default; + + ~EmbCacheManagerImpl() override = default; + + int CreateCacheForTable(const EmbCacheInfo& embCacheInfo, const std::vector& initializerInfos, + int64_t invalidKey, uint64_t prefillBufferSize, uint32_t refillThreadNum) override; + + int GetSwapPairsAndKey2Offset(const std::string& tableName, std::vector& keys, + KeyOffsetPair& swapInKoPair, KeyOffsetPair& swapOutKoPair) override; + + int EmbeddingLookup(const std::string& tableName, const std::vector& keys, float* embAddr, + uint32_t threadNum) override; + + int EmbeddingLookupAddrs(const std::string& tableName, const std::vector& keys, + std::vector& addrs, uint32_t threadNum) override; + + int EmbeddingUpdate(const std::string& tableName, const std::vector& keys, float* embAddr, + uint32_t threadNum) override; + + int EmbeddingRemove(const std::string& tableName, const std::vector& keys, uint32_t threadNum) override; + + int EmbeddingLookupAndRemove(const std::string& tableName, const std::vector& keys, float* embAddr, + uint32_t threadNum) override; + + int RemoveEmbsByKeys(const std::string& tableName, const std::vector& keys) override; + + int GetEmbTableNames(std::vector& allTableNames) override; + + int ExportDeviceKeyOffsetPairs(const std::string& tableName, + std::vector>& koVec) override; + + int Serialize(const std::string& tableName, std::vector& buffer) override; + + int Deserialize(const std::string& tableName, const std::vector& buffer) override; + + void Destroy() override; + + int GetEmbTableInfos(std::string tableName, std::vector& keys, + std::vector>& embeddings, + std::vector>& optimizerSlots) override; + + int 
LoadEmbTableInfos(std::string tableName, const std::vector& keys, + const std::vector>& embeddings, + const std::vector>& optimizerSlots) override; + + int BackUpTrainStatus(const std::string& tableName) override; + + int RecoverTrainStatus(const std::string& tableName) override; + + int ResetOffsetMappers() override; + + uint32_t GetUsage(const std::string& tableName) override; + +private: + std::map embCacheInfos; + std::map offsetMappers; + std::map offsetMappersBackUp; + std::map embTables; + std::map>> kvVecsBackUp; + + int CheckValidTableName(const std::string& tableName); + + bool CheckInitializer(uint32_t extEmbSize, std::vector initializerInfos); + + bool CheckValidThreadNum(uint32_t threadNum); + + int CheckGetSwapPairsAndKey2Offset(const std::string& tableName, const KeyOffsetPair& swapInKoPair, + const KeyOffsetPair& swapOutKoPair); + + int CheckCreateTableName(const std::string& tableName); +}; +} // namespace EmbCache +#endif // EMBEDDING_CACHE_MANAGER_H diff --git a/src/AccCTR/src/embedding_cache/common.h b/src/AccCTR/src/embedding_cache/common.h new file mode 100644 index 0000000000000000000000000000000000000000..d9841541d99b40088a1d3ab5a9092e82602f2a2e --- /dev/null +++ b/src/AccCTR/src/embedding_cache/common.h @@ -0,0 +1,66 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ ==============================================================================*/ + +#ifndef MXREC_COMMON_H +#define MXREC_COMMON_H + +#include "limited_set.h" + +#ifndef HM_UNLIKELY +#define HM_UNLIKELY(x) __builtin_expect(!!(x), 0) +#endif + +#ifndef HM_LIKELY +#define HM_LIKELY(x) __builtin_expect(!!(x), 1) +#endif + +namespace EmbCache { + + +enum class FkvState { + FKV_EXIST, + FKV_NOT_EXIST, + FKV_KEY_CONFLICT, + FKV_BEFORE_PUT_FUNC_FAIL, + FKV_BEFORE_REMOVE_FUNC_FAIL, + FKV_NO_SPACE, + FKV_FAIL, +}; + +enum class BeforePutFuncState { + BEFORE_SUCCESS, + BEFORE_NO_SPACE, + BEFORE_FAIL, +}; + +enum class BeforeRemoveFuncState { + BEFORE_SUCCESS, + BEFORE_FAIL, +}; + +extern int64_t INVALID_KEY; +constexpr uint64_t TABLE_NAME_MAX_SIZE = 1024; +const uint32_t VOCAB_CACHE_RATIO = 15; +constexpr float NORMAL_MEAN_MAX = 1e9; +constexpr float NORMAL_MEAN_MIN = -1e9; +constexpr float NORMAL_STDDEV_MAX = 100; +constexpr float NORMAL_STDDEV_MIN = 0; +constexpr float CONSTANT_VALUE_MAX = 1e9; +constexpr float CONSTANT_VALUE_MIN = -1e9; +constexpr float INIT_K_MAX = 10000; +constexpr float INIT_K_MIN = -10000; +const int INVALID_EMB_SIZE = -1; +const size_t MEMSET_S_MAX_SIZE = 2LL * 1024 * 1024 * 1024 - 1; +} +#endif // MXREC_COMMON_H diff --git a/src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.cpp b/src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dc59a303670458128a2e4e577339b9fcc080ab08 --- /dev/null +++ b/src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.cpp @@ -0,0 +1,475 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ ==============================================================================*/
+
+#include "emb_local_table.h"
+
+#include <future>
+#include <thread>
+
+#include "error_code.h"
+#include "securec.h"
+
+using namespace std;
+using namespace EmbCache;
+using namespace ock;
+using namespace ock::ctr;
+
+bool EmbLocalTable::Initialize(const EmbCacheInfo& embCacheInfo, uint64_t reserve,
+                               const std::vector<InitializerInfo>& initializerInfos, const EmbPoolParam& embPoolParam)
+{
+    emExpendMemInfo = make_shared<AutoRefillEmbeddingMemoryPool>(embPoolParam.prefillBufferSize, initializerInfos,
+                                                                 embCacheInfo.extEmbeddingSize, embCacheInfo.vocabSize,
+                                                                 embPoolParam.refillThreadNum);
+    embeddingSize = embCacheInfo.embeddingSize;
+    extEmbeddingSize = embCacheInfo.extEmbeddingSize;
+    return embMap.Initialize(reserve, embCacheInfo.vocabSize, emExpendMemInfo);
+}
+
+void EmbLocalTable::UnInitialize()
+{
+    embMap.UnInitialize();
+}
+
+int EmbLocalTable::FindAndPutIfNotFound(uint64_t key, uint64_t& value)
+{
+    FkvState ret = embMap.FindAndPutIfNotFound(key, value);
+    if (ret == FkvState::FKV_FAIL) {
+        return H_ERROR;
+    }
+    if (ret == FkvState::FKV_BEFORE_PUT_FUNC_FAIL) {
+        return H_MEMORY_ALLOC_ERROR;
+    }
+    if (ret == FkvState::FKV_NO_SPACE) {
+        return H_HOST_VOCAB_SIZE_TOO_SMALL;
+    }
+    return H_OK;
+}
+
+bool EmbLocalTable::Remove(uint64_t key)
+{
+    return embMap.Remove(key) != FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL;
+}
+
+int EmbLocalTable::RemoveByKeys(const std::vector<uint64_t>& keys, uint32_t threadNum)
+{
+    if (threadNum == 1) {
+        for (uint64_t key : keys) {
+            if (!Remove(key)) {
+                return H_ERROR;
+            }
+        }
+        return H_OK;
+    }
+    // each thread handles the keys in the range [start[threadId], start[threadId + 1])
+    uint32_t m = keys.size() % threadNum;
+    vector<uint64_t> start(threadNum + 1);
+    // the first keys.size() % threadNum threads take ceil(keys.size() / threadNum) keys
+    for (uint32_t threadId = 0; threadId < m; threadId++) {
+        start[threadId] = ((keys.size() + threadNum - 1) / threadNum) * threadId;
+    }
+    // the remaining threads take the floor
+    for (uint32_t threadId = m; threadId <= threadNum; threadId++) {
+        start[threadId] = (keys.size() / threadNum) * threadId + m;
+    }
+
+    vector<std::future<int>> threads(threadNum);
+    for (uint32_t threadId = 0; threadId < threadNum; threadId++) {
+        threads[threadId] = std::async(std::launch::async, [&, threadId]() {
+            for (uint64_t i = start[threadId]; i < start[threadId + 1]; i++) {
+                if (!Remove(keys[i])) {
+                    return H_ERROR;
+                }
+            }
+            return H_OK;
+        });
+    }
+    for (auto& t : threads) {
+        auto res = t.get();
+        if (res != H_OK) {
+            return res;
+        }
+    }
+    return H_OK;
+}
+
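
The same ceil/floor partition recurs in `Gather`, `GatherAddrs`, `GatherAndRemove` and `Scatter` below. Factored out as a standalone sketch (a hypothetical helper, not part of this change), the scheme is:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper illustrating the partition used above: the first n % t ranges
// get ceil(n / t) keys, the rest get floor(n / t), so all n keys are covered exactly once.
std::vector<uint64_t> SplitRange(uint64_t n, uint32_t t) {
    std::vector<uint64_t> start(t + 1);
    uint64_t m = n % t;
    for (uint32_t id = 0; id < m; id++) {
        start[id] = ((n + t - 1) / t) * id; // ceil-sized blocks
    }
    for (uint32_t id = m; id <= t; id++) {
        start[id] = (n / t) * id + m;       // floor-sized blocks, shifted past the ceil ones
    }
    return start; // thread i handles [start[i], start[i + 1])
}
```

For example, `SplitRange(10, 4)` yields boundaries `{0, 3, 6, 8, 10}`, i.e. block sizes 3, 3, 2, 2.
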
dstSize: " + std::to_string(memSize)); + return H_COPY_ERROR; + } + } + } + + return H_OK; +} + +int EmbLocalTable::Gather(uint64_t startAddr, const vector& keys, uint32_t threadNum) +{ + if (threadNum == 1) { + return OneThreadHandle(startAddr, keys, true); + } + + // 每个线程处理[start[threadId],start[threadId+1])这个区间的key + uint32_t m = keys.size() % threadNum; + vector start(threadNum + 1); + // 前keys.size()%threadNum个线程向上取整 + for (uint32_t threadId = 0; threadId < m; threadId++) { + start[threadId] = ((keys.size() + threadNum - 1) / threadNum) * threadId; + } + // 后面的向下取整 + for (uint32_t threadId = m; threadId <= threadNum; threadId++) { + start[threadId] = (keys.size() / threadNum) * threadId + m; + } + + vector threads(threadNum); + int ret = H_OK; + for (uint32_t threadId = 0; threadId < threadNum; threadId++) { + threads[threadId] = thread([&, threadId] { + for (uint64_t i = start[threadId]; i < start[threadId + 1]; i++) { + uint64_t embAddr; + int temp_ret = FindAndPutIfNotFound(keys[i], embAddr); + if (temp_ret != H_OK) { + ret = temp_ret; + return; + } + uint64_t memSize = emExpendMemInfo->extEmbeddingSize * sizeof(float); + auto addr = startAddr + i * memSize; + auto rc = memcpy_s(reinterpret_cast(addr), memSize, reinterpret_cast(embAddr), memSize); + if (rc != 0) { + ExternalLogger::PrintLog(LogLevel::ERROR, "memcpy_s failed... dstSize: " + std::to_string(memSize)); + ret = H_COPY_ERROR; + return; + } + } + }); + } + for (auto& t : threads) { + t.join(); + } + return ret; +} + +int EmbLocalTable::GatherAddrs(const std::vector& keys, std::vector& addrs, uint32_t threadNum) +{ + if (threadNum == 1) { + addrs.resize(keys.size()); + for (uint64_t i = 0; i < keys.size(); i++) { + int temp_ret = FindAndPutIfNotFound(keys[i], reinterpret_cast(addrs[i])); + if (temp_ret != H_OK) { + return temp_ret; + } + } + return H_OK; + } + // 每个线程处理[start[threadId],start[threadId+1])这个区间的key + uint32_t m = keys.size() % threadNum; + vector start(threadNum + 1); + // 前keys.size()%threadNum个线程向上取整 + for (uint32_t threadId = 0; threadId < m; threadId++) { + start[threadId] = ((keys.size() + threadNum - 1) / threadNum) * threadId; + } + // 后面的向下取整 + for (uint32_t threadId = m; threadId <= threadNum; threadId++) { + start[threadId] = (keys.size() / threadNum) * threadId + m; + } + addrs.resize(keys.size()); + + vector threads(threadNum); + int ret = H_OK; + for (uint32_t threadId = 0; threadId < threadNum; threadId++) { + threads[threadId] = thread([&, threadId] { + for (uint64_t i = start[threadId]; i < start[threadId + 1]; i++) { + int temp_ret = FindAndPutIfNotFound(keys[i], reinterpret_cast(addrs[i])); + if (temp_ret != H_OK) { + ret = temp_ret; + return; + } + } + }); + } + for (auto& t : threads) { + t.join(); + } + return ret; +} + +// 如果多线程使用,严格保证传入的key线程间不会重复(unique key),否则可能出现未定义结果 +int EmbLocalTable::GatherAndRemove(uint64_t startAddr, const vector& keys, uint32_t threadNum) +{ + if (threadNum == 1) { + for (uint64_t i = 0; i < keys.size(); i++) { + uint64_t memSize = emExpendMemInfo->extEmbeddingSize * sizeof(float); + auto addr = startAddr + i * memSize; + auto ret = embMap.FindAndRemoveIfFound(keys[i], addr); // 如果找到了就拷贝出来然后把key删了 + if (ret == FkvState::FKV_NOT_EXIST) { // 没找到key,给一个新的初始化值并且不需要存入key + auto* embAddr = reinterpret_cast(addr); + for (const auto& initializerInfo : emExpendMemInfo->initializerInfos) { + initializerInfo.initializer->GenerateData(embAddr, INVALID_EMB_SIZE); + } + } else if (ret == FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL) { + ExternalLogger::PrintLog(LogLevel::ERROR, 
"memcpy_s failed... dstSize: " + std::to_string(memSize)); + return H_COPY_ERROR; + } + } + return H_OK; + } + + // 每个线程处理[start[threadId],start[threadId+1])这个区间的key + uint32_t m = keys.size() % threadNum; + vector start(threadNum + 1); + // 前keys.size()%threadNum个线程向上取整 + for (uint32_t threadId = 0; threadId < m; threadId++) { + start[threadId] = ((keys.size() + threadNum - 1) / threadNum) * threadId; + } + // 后面的向下取整 + for (uint32_t threadId = m; threadId <= threadNum; threadId++) { + start[threadId] = (keys.size() / threadNum) * threadId + m; + } + + vector threads(threadNum); + int retVal = H_OK; + for (uint32_t threadId = 0; threadId < threadNum; threadId++) { + threads[threadId] = thread([&, threadId] { + for (uint64_t i = start[threadId]; i < start[threadId + 1]; i++) { + uint64_t memSize = emExpendMemInfo->extEmbeddingSize * sizeof(float); + auto addr = startAddr + i * memSize; + auto ret = embMap.FindAndRemoveIfFound(keys[i], addr); // 如果找到了就拷贝出来然后把key删了 + if (ret == FkvState::FKV_NOT_EXIST) { // 没找到key,给一个新的初始化值并且不需要存入key + auto* embAddr = reinterpret_cast(addr); + for (const auto& initializerInfo : emExpendMemInfo->initializerInfos) { + initializerInfo.initializer->GenerateData(embAddr, INVALID_EMB_SIZE); + } + } else if (ret == FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL) { + ExternalLogger::PrintLog(LogLevel::ERROR, "memcpy_s failed... dstSize: " + std::to_string(memSize)); + retVal = H_COPY_ERROR; + return; + } + } + }); + } + for (auto& t : threads) { + t.join(); + } + return retVal; +} + +int EmbLocalTable::Scatter(const uint64_t startAddr, const vector& keys, uint32_t threadNum) +{ + if (threadNum == 1) { // 单线程版本 + return OneThreadHandle(startAddr, keys, false); + } + + // 多线程版本 + // 每个线程处理[start[threadId],start[threadId+1])这个区间的key + uint32_t m = keys.size() % threadNum; + vector start(threadNum + 1); + // 前keys.size()%threadNum个线程向上取整 + for (uint32_t threadId = 0; threadId < m; threadId++) { + start[threadId] = ((keys.size() + threadNum - 1) / threadNum) * threadId; + } + // 后面的向下取整 + for (uint32_t threadId = m; threadId <= threadNum; threadId++) { + start[threadId] = (keys.size() / threadNum) * threadId + m; + } + + vector threads(threadNum); + int ret = H_OK; + for (uint32_t threadId = 0; threadId < threadNum; threadId++) { + threads[threadId] = thread([&, threadId] { + for (uint64_t i = start[threadId]; i < start[threadId + 1]; i++) { + uint64_t embAddr; + int temp_ret = FindAndPutIfNotFound(keys[i], embAddr); // 获取每个key的embedding对应首地址 + if (temp_ret != H_OK) { + ret = temp_ret; + return; + } + uint64_t memSize = emExpendMemInfo->extEmbeddingSize * sizeof(float); + auto addr = startAddr + i * memSize; + auto rc = memcpy_s(reinterpret_cast(embAddr), memSize, // 按顺序把新的embedding拷贝到对应地址中 + reinterpret_cast(addr), memSize); + if (rc != 0) { + ExternalLogger::PrintLog(LogLevel::ERROR, "memcpy_s failed... 
dstSize: " + std::to_string(memSize)); + ret = H_COPY_ERROR; + return; + } + } + }); + } + for (auto& t : threads) { + t.join(); + } + return ret; +} + +// 导出存储的所有kv对 +vector> EmbLocalTable::ExportVec() +{ + return embMap.ExportVec(); +} + +template +void EmbLocalTable::insertData(vector& buffer, T& data) +{ + buffer.insert(buffer.end(), (char*)&data, (char*)&data + sizeof(data)); +} + +template +bool EmbLocalTable::getData(const vector& buffer, T& data, uint64_t& i) +{ + if (i + sizeof(T) > buffer.size()) { + return false; + } + data = *reinterpret_cast(&buffer[i]); + i += sizeof(T); + return true; +} + +// 把所存储的key-embedding信息序列化 +vector EmbLocalTable::Serialize() +{ + vector buffer; + vector> kvVec = ExportVec(); + + for (auto& p : kvVec) { + uint64_t key = p.first; + uint64_t value = p.second; + insertData(buffer, key); + auto* addr = reinterpret_cast(value); + buffer.insert(buffer.end(), reinterpret_cast(addr), + reinterpret_cast((addr + emExpendMemInfo->extEmbeddingSize))); + } + return buffer; +} + +// 反序列化key-embedding,存进map +bool EmbLocalTable::Deserialize(const vector& buffer) +{ + uint64_t i = 0; + while (i < buffer.size()) { + uint64_t key; + if (!getData(buffer, key, i)) { + ExternalLogger::PrintLog(LogLevel::ERROR, "get data failed!"); + return false; + } + uint64_t value = 0; + if (FindAndPutIfNotFound(key, value) != H_OK) { + ExternalLogger::PrintLog(LogLevel::ERROR, "FindAndPutIfNotFound failed!"); + return false; + } + + auto* addr = reinterpret_cast(value); + for (uint32_t j = 0; j < emExpendMemInfo->extEmbeddingSize; j++) { + if (!getData(buffer, addr[j], i)) { + ExternalLogger::PrintLog(LogLevel::ERROR, "get data failed!"); + return false; + } + } + } + return true; +} + +uint32_t EmbLocalTable::GetUsage() +{ + return embMap.current_size; +} + +void EmbLocalTable::GetEmbTableInfos(std::vector& keys, std::vector>& embeddings, + std::vector>& optimizerSlots) +{ + vector> kvVec = ExportVec(); + + for (auto& p : kvVec) { + std::vector curEmbedding; + keys.emplace_back(p.first); + auto* addr = reinterpret_cast(p.second); + curEmbedding.insert(curEmbedding.end(), addr, reinterpret_cast((addr + embeddingSize))); + embeddings.emplace_back(curEmbedding); + if (extEmbeddingSize > embeddingSize) { + std::vector curOptimizerSlot; + curOptimizerSlot.insert(curOptimizerSlot.end(), reinterpret_cast(addr + embeddingSize), + reinterpret_cast((addr + extEmbeddingSize))); + optimizerSlots.emplace_back(curOptimizerSlot); + } + } +} + +bool EmbLocalTable::LoadEmbTableInfos(const std::vector& keys, + const std::vector>& embeddings, + const std::vector>& optimizerSlots) +{ + if (keys.size() != embeddings.size()) { + ExternalLogger::PrintLog(LogLevel::ERROR, "the size of keys and embeddings should be same!"); + return false; + } + uint32_t optimizerSlotSize = extEmbeddingSize - embeddingSize; + if (optimizerSlotSize > 0) { + if (keys.size() != optimizerSlots.size()) { + ExternalLogger::PrintLog(LogLevel::ERROR, "the size of keys and optimizerSlots should be same!"); + return false; + } + } + for (uint64_t i = 0; i < keys.size(); i++) { + uint64_t value = 0; + if (FindAndPutIfNotFound(keys[i], value) != H_OK) { + ExternalLogger::PrintLog(LogLevel::ERROR, "FindAndPutIfNotFound failed!"); + return false; + } + if (embeddings[i].size() != embeddingSize) { + ExternalLogger::PrintLog(LogLevel::ERROR, + "The size of entering Embedding does not equals to embeddingSize"); + return false; + } + auto* addr = reinterpret_cast(value); + auto rc = memcpy_s(addr, embeddingSize * sizeof(float), 
embeddings[i].data(), embeddingSize * sizeof(float)); + if (rc != 0) { + ExternalLogger::PrintLog(LogLevel::ERROR, "embedding memcpy_s failed... "); + return false; + } + if (optimizerSlotSize > 0) { + if (optimizerSlots[i].size() != optimizerSlotSize) { + ExternalLogger::PrintLog( + LogLevel::ERROR, + "The size of entering optimizerSlot does not equals to extEmbeddingSize - embeddingSize"); + return false; + } + auto rc2 = memcpy_s(reinterpret_cast(addr + embeddingSize), optimizerSlotSize * sizeof(float), + optimizerSlots[i].data(), optimizerSlotSize * sizeof(float)); + if (rc2 != 0) { + ExternalLogger::PrintLog(LogLevel::ERROR, "optimizerSlot memcpy_s failed... "); + return false; + } + } + } + return true; +} \ No newline at end of file diff --git a/src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.h b/src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.h new file mode 100644 index 0000000000000000000000000000000000000000..ee93bb91f76a89df4c3f9e35629a66518f0ad8f7 --- /dev/null +++ b/src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.h @@ -0,0 +1,84 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef EMB_LOCAL_TABLE_H +#define EMB_LOCAL_TABLE_H + +#include +#include +#include + +#include "offset_mapper/address_mapper.h" + +namespace EmbCache { +struct EmbPoolParam { + uint64_t prefillBufferSize; + uint32_t refillThreadNum; +}; + +class EmbLocalTable { +public: + EmbLocalTable() = default; + + ~EmbLocalTable() = default; + + bool Initialize(const EmbCacheInfo& embCacheInfo, uint64_t reserve, + const std::vector& initializerInfos, const EmbPoolParam& embPoolParam); + + void UnInitialize(); + + int FindAndPutIfNotFound(uint64_t key, uint64_t& value); + + bool Remove(uint64_t key); + + int RemoveByKeys(const std::vector& keys, uint32_t threadNum); + + int Gather(uint64_t startAddr, const std::vector& keys, uint32_t threadNum); + + int GatherAddrs(const std::vector& keys, std::vector& addrs, uint32_t threadNum); + + int Scatter(uint64_t startAddr, const std::vector& keys, uint32_t threadNum); + + int OneThreadHandle(uint64_t startAddr, const std::vector& keys, bool isGather); + + int GatherAndRemove(uint64_t startAddr, const std::vector& keys, uint32_t threadNum); + + std::vector> ExportVec(); + + std::vector Serialize(); + + bool Deserialize(const std::vector& buffer); + + uint32_t GetUsage(); + + void GetEmbTableInfos(std::vector& keys, std::vector>& embeddings, + std::vector>& optimizerSlots); + + bool LoadEmbTableInfos(const std::vector& keys, const std::vector>& embeddings, + const std::vector>& optimizerSlots); + +private: + std::shared_ptr emExpendMemInfo; + AddressMapper embMap; + uint32_t embeddingSize; + uint32_t extEmbeddingSize; + + template + void insertData(std::vector& buffer, T& data); + + template + bool getData(const std::vector& buffer, T& data, uint64_t& i); +}; +} // 
namespace EmbCache +#endif // EMB_LOCAL_TABLE_H diff --git a/src/AccCTR/src/embedding_cache/initializer/constant_initializer/constant_initializer.cpp b/src/AccCTR/src/embedding_cache/initializer/constant_initializer/constant_initializer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0e0ecb0df59df343ad9d9cc8f47d485141a93d67 --- /dev/null +++ b/src/AccCTR/src/embedding_cache/initializer/constant_initializer/constant_initializer.cpp @@ -0,0 +1,62 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#include "embedding_cache.h" +#include "embedding_cache/common.h" +#include "external_logger.h" + +using namespace std; +using namespace EmbCache; +using namespace ock; + +ConstantInitializer::ConstantInitializer(uint32_t start, uint32_t len, float value, float initK) + : start(start), len(len) +{ + if (value > CONSTANT_VALUE_MAX) { + ExternalLogger::PrintLog(LogLevel::WARN, "constant value is greater than " + + std::to_string(CONSTANT_VALUE_MAX) + ", and will use " + std::to_string(CONSTANT_VALUE_MAX) + "."); + constantValue = CONSTANT_VALUE_MAX; + } else if (value < CONSTANT_VALUE_MIN) { + ExternalLogger::PrintLog(LogLevel::WARN, "constant value is less than " + std::to_string(CONSTANT_VALUE_MIN) + + ", and will use " + std::to_string(CONSTANT_VALUE_MIN) + "."); + constantValue = CONSTANT_VALUE_MIN; + } else { + constantValue = value; + } + if (initK > INIT_K_MAX) { + ExternalLogger::PrintLog(LogLevel::WARN, "constant initK is greater than " + std::to_string(INIT_K_MAX) + + ", and will use " + std::to_string(INIT_K_MAX) + "."); + initParam = INIT_K_MAX; + } else if (initK < INIT_K_MIN) { + ExternalLogger::PrintLog(LogLevel::WARN, "constant initK is less than " + std::to_string(INIT_K_MIN) + + ", and will use " + std::to_string(INIT_K_MIN) + "."); + initParam = INIT_K_MIN; + } else { + initParam = initK; + } +} + +void ConstantInitializer::GenerateData(float* emb, int embSize) +{ + if (len == 0) { + return; + } + if (embSize != INVALID_EMB_SIZE && embSize < static_cast(start + len)) { + ExternalLogger::PrintLog(LogLevel::WARN, + "InitializeInfo start " + std::to_string(start) + " + len " + std::to_string(len) + + " is larger than embedding size " + std::to_string(embSize)); + return; + } + std::fill_n(emb + start, len, initParam * constantValue); +} diff --git a/src/AccCTR/src/embedding_cache/initializer/initializer.cpp b/src/AccCTR/src/embedding_cache/initializer/initializer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..887aaee046cf1ddecafe2f36744d803a0841d12c --- /dev/null +++ b/src/AccCTR/src/embedding_cache/initializer/initializer.cpp @@ -0,0 +1,56 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
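
`ConstantInitializer` clamps both `value` and `initK` into their allowed ranges and then fills its segment with `initK * value`, so a segment such as `{start=8, len=4, value=1.0, initK=0.0}` zeroes an optimizer-slot region. A short sketch of the observable behavior (the variable names are illustrative):

```cpp
// Behavior sketch: fills emb[start, start + len) with initK * value (after clamping).
float row[12] = {};
EmbCache::ConstantInitializer zeroSlots(/*start=*/8, /*len=*/4, /*value=*/1.0f, /*initK=*/0.0f);
zeroSlots.GenerateData(row, /*embSize=*/12); // row[8..11] become 0.0f * 1.0f == 0.0f
```
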
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#include + +#include "external_logger.h" +#include "embedding_cache.h" + +using namespace EmbCache; + +ConstantInitializerInfo::ConstantInitializerInfo(float constantValue, float initK) + : constantValue(constantValue), initK(initK) +{} + +NormalInitializerInfo::NormalInitializerInfo(float mean, float stddev, uint32_t seed, float initK) + : mean(mean), stddev(stddev), seed(seed), initK(initK) +{} + +InitializerInfo::InitializerInfo(std::string &name, uint32_t start, uint32_t len, + ConstantInitializerInfo constantInitializerInfo) + : name(name), start(start), len(len), constantInitializerInfo(constantInitializerInfo) +{ + if (name == "constant_initializer") { + initializerType = InitializerType::CONSTANT; + initializer = std::make_shared(start, len, constantInitializerInfo.constantValue, + constantInitializerInfo.initK); + } else { + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "Invalid Initializer Type."); + } +} + +InitializerInfo::InitializerInfo(std::string &name, uint32_t start, uint32_t len, + NormalInitializerInfo normalInitializerInfo) + : name(name), start(start), len(len), normalInitializerInfo(normalInitializerInfo) +{ + if (name == "truncated_normal_initializer") { + initializerType = InitializerType::TRUNCATED_NORMAL; + initializer = std::make_shared(start, len, normalInitializerInfo); + } else if (name == "random_normal_initializer") { + initializerType = InitializerType::RANDOM_NORMAL; + initializer = std::make_shared(start, len, normalInitializerInfo); + } else { + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "Invalid Initializer Type."); + } +} diff --git a/src/AccCTR/src/embedding_cache/initializer/random_normal_initializer/random_normal_initializer.cpp b/src/AccCTR/src/embedding_cache/initializer/random_normal_initializer/random_normal_initializer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c4b01062ffc3f478a22fb64c048fbc3a84e52b0a --- /dev/null +++ b/src/AccCTR/src/embedding_cache/initializer/random_normal_initializer/random_normal_initializer.cpp @@ -0,0 +1,78 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
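
The two `InitializerInfo` constructors dispatch on the name string; an unrecognized name only logs an error and leaves `initializer` null, which `CheckInitializer` later rejects as a coverage problem. A construction sketch under assumed segment sizes (note the constructors take the name by non-const reference, so lvalue strings are used):

```cpp
// Covering a 12-float extended row: floats [0,8) from a truncated normal,
// floats [8,12) from a constant; the names must match the strings checked above.
std::string normalName = "truncated_normal_initializer";
std::string constName = "constant_initializer";
std::vector<EmbCache::InitializerInfo> infos = {
    EmbCache::InitializerInfo(normalName, 0, 8,
        EmbCache::NormalInitializerInfo(/*mean=*/0.0f, /*stddev=*/0.05f, /*seed=*/7, /*initK=*/1.0f)),
    EmbCache::InitializerInfo(constName, 8, 4,
        EmbCache::ConstantInitializerInfo(/*constantValue=*/0.0f, /*initK=*/1.0f)),
};
```
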
+ ==============================================================================*/ + +#include +#include +#include "embedding_cache.h" +#include "embedding_cache/common.h" +#include "external_logger.h" + +using namespace EmbCache; +using namespace ock; + +RandomNormalInitializer::RandomNormalInitializer(uint32_t start, uint32_t len, NormalInitializerInfo &initInfo) + : start(start), len(len), mean(initInfo.mean), stddev(initInfo.stddev), seed(initInfo.seed) +{ + // 校验stddev mean及initK值范围 + if (initInfo.mean > NORMAL_MEAN_MAX) { + ExternalLogger::PrintLog(LogLevel::WARN, "random normal mean param is greater than " + + std::to_string(NORMAL_MEAN_MAX) + ", and will use " + std::to_string(NORMAL_MEAN_MAX) + "."); + mean = NORMAL_MEAN_MAX; + } else if (initInfo.mean < NORMAL_MEAN_MIN) { + ExternalLogger::PrintLog(LogLevel::WARN, "random normal mean param is less than " + + std::to_string(NORMAL_MEAN_MIN) + ", and will use " + std::to_string(NORMAL_MEAN_MIN) + "."); + mean = NORMAL_MEAN_MIN; + } else { + mean = initInfo.mean; + } + if (initInfo.stddev > NORMAL_STDDEV_MAX) { + ExternalLogger::PrintLog(LogLevel::WARN, "random normal stddev param is greater than " + + std::to_string(NORMAL_STDDEV_MAX) + ", and will use " + std::to_string(NORMAL_STDDEV_MAX) + "."); + stddev = NORMAL_STDDEV_MAX; + } else if (initInfo.stddev < NORMAL_STDDEV_MIN) { + ExternalLogger::PrintLog(LogLevel::WARN, "random normal stddev param is less than " + + std::to_string(NORMAL_STDDEV_MIN) + ", and will use " + std::to_string(NORMAL_STDDEV_MIN) + "."); + stddev = NORMAL_STDDEV_MIN; + } else { + stddev = initInfo.stddev; + } + if (initInfo.initK > INIT_K_MAX) { + ExternalLogger::PrintLog(LogLevel::WARN, "random normal initK is greater than " + std::to_string(INIT_K_MAX) + + ", and will use " + std::to_string(INIT_K_MAX) + "."); + initParam = INIT_K_MAX; + } else if (initInfo.initK < INIT_K_MIN) { + ExternalLogger::PrintLog(LogLevel::WARN, "random normal initK is less than " + std::to_string(INIT_K_MIN) + + ", and will use " + std::to_string(INIT_K_MIN) + "."); + initParam = INIT_K_MIN; + } else { + initParam = initInfo.initK; + } + + generator = std::default_random_engine(seed); + distribution = std::normal_distribution(mean, stddev); +} + +void RandomNormalInitializer::GenerateData(float* emb, int embSize) +{ + if (len == 0) { + return; + } + if (embSize != INVALID_EMB_SIZE && embSize < static_cast(start + len)) { + ExternalLogger::PrintLog(LogLevel::WARN, + "InitializeInfo start " + std::to_string(start) + " + len " + std::to_string(len) + + " is larger than embedding size " + std::to_string(embSize)); + return; + } + std::generate_n(emb + start, len, [this]() { return initParam * distribution(generator); }); +} \ No newline at end of file diff --git a/src/AccCTR/src/embedding_cache/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp b/src/AccCTR/src/embedding_cache/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..95e097575df3c5c5fb837dde8bc1b250d6942f85 --- /dev/null +++ b/src/AccCTR/src/embedding_cache/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp @@ -0,0 +1,94 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#include +#include "embedding_cache.h" +#include "embedding_cache/common.h" +#include "external_logger.h" + +using namespace EmbCache; +using namespace ock; + +TruncatedNormalInitializer::TruncatedNormalInitializer(uint32_t start, uint32_t len, NormalInitializerInfo &initInfo) + : start(start), len(len), mean(initInfo.mean), stddev(initInfo.stddev), seed(initInfo.seed) +{ + // 校验stddev mean及initK值范围 + if (initInfo.mean > NORMAL_MEAN_MAX) { + ExternalLogger::PrintLog(LogLevel::WARN, "truncated normal mean param is greater than " + + std::to_string(NORMAL_MEAN_MAX) + ", and will use " + std::to_string(NORMAL_MEAN_MAX) + "."); + mean = NORMAL_MEAN_MAX; + } else if (initInfo.mean < NORMAL_MEAN_MIN) { + ExternalLogger::PrintLog(LogLevel::WARN, "truncated normal mean param is less than " + + std::to_string(NORMAL_MEAN_MIN) + ", and will use " + std::to_string(NORMAL_MEAN_MIN) + "."); + mean = NORMAL_MEAN_MIN; + } else { + mean = initInfo.mean; + } + + if (initInfo.stddev > NORMAL_STDDEV_MAX) { + ExternalLogger::PrintLog(LogLevel::WARN, "truncated normal stddev param is greater than " + + std::to_string(NORMAL_STDDEV_MAX) + ", and will use " + std::to_string(NORMAL_STDDEV_MAX) + "."); + stddev = NORMAL_STDDEV_MAX; + } else if (initInfo.stddev < NORMAL_STDDEV_MIN) { + ExternalLogger::PrintLog(LogLevel::WARN, "truncated normal stddev param is less than " + + std::to_string(NORMAL_STDDEV_MIN) + ", and will use " + std::to_string(NORMAL_STDDEV_MIN) + "."); + stddev = NORMAL_STDDEV_MIN; + } else { + stddev = initInfo.stddev; + } + + if (abs(stddev) < std::numeric_limits::epsilon()) { + ExternalLogger::PrintLog( + LogLevel::WARN, + "truncated normal stddev param is zero, initialization can be slow, suggest using constant initializer"); + } + + if (initInfo.initK > INIT_K_MAX) { + ExternalLogger::PrintLog(LogLevel::WARN, "truncated normal initK is greater than " + + std::to_string(INIT_K_MAX) + ", and will use " + std::to_string(INIT_K_MAX) + "."); + initParam = INIT_K_MAX; + } else if (initInfo.initK < INIT_K_MIN) { + ExternalLogger::PrintLog(LogLevel::WARN, "truncated normal initK is less than " + std::to_string(INIT_K_MIN) + + ", and will use " + std::to_string(INIT_K_MIN) + "."); + initParam = INIT_K_MIN; + } else { + initParam = initInfo.initK; + } + + generator = std::default_random_engine(seed); + distribution = std::normal_distribution(mean, stddev); + minBound = initParam * (mean - static_cast(boundNum) * stddev); + maxBound = initParam * (mean + static_cast(boundNum) * stddev); +} + + +void TruncatedNormalInitializer::GenerateData(float* emb, int embSize) +{ + if (len == 0) { + return; + } + if (embSize != INVALID_EMB_SIZE && embSize < static_cast(start + len)) { + ExternalLogger::PrintLog(LogLevel::WARN, + "InitializeInfo start " + std::to_string(start) + " + len " + std::to_string(len) + + " is larger than embedding size " + std::to_string(embSize)); + return; + } + std::generate_n(emb + start, len, [this]() { + float tmp = initParam * distribution(generator); + while (tmp < minBound || tmp > maxBound) { + tmp = 
initParam * distribution(generator); + } + return tmp; + }); +} diff --git a/src/AccCTR/src/embedding_cache/limited_set.h b/src/AccCTR/src/embedding_cache/limited_set.h new file mode 100644 index 0000000000000000000000000000000000000000..f7bc2e1e6fac570becffce4cd772da036afc285c --- /dev/null +++ b/src/AccCTR/src/embedding_cache/limited_set.h @@ -0,0 +1,135 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef MXREC_LIMITED_SET_H +#define MXREC_LIMITED_SET_H + +#include +#include + +namespace EmbCache { + +static constexpr int64_t NODE_DEFAULT_VALUE = -1; + +class LimitedSet { +public: + struct Node { + uint64_t value; + Node *prev, *next; + Node(uint64_t val = NODE_DEFAULT_VALUE) : value(val), prev(nullptr), next(nullptr) {} + }; + + LimitedSet(uint64_t maxRange) : head(new Node(NODE_DEFAULT_VALUE)), tail(new Node(NODE_DEFAULT_VALUE)) + { + nodes.resize(maxRange); + for (auto &node : nodes) { + node = new Node(NODE_DEFAULT_VALUE); + } + head->next = tail; + tail->prev = head; + } + + ~LimitedSet() + { + for (auto &node : nodes) { + delete node; + } + delete head; + delete tail; + } + + LimitedSet(const LimitedSet& other): head(new Node(NODE_DEFAULT_VALUE)), tail(new Node(NODE_DEFAULT_VALUE)) + { + nodes.resize(other.nodes.size()); + for (auto& node: nodes) { + node = new Node(NODE_DEFAULT_VALUE); + } + + head->next = tail; + tail->prev = head; + + for (Node* node = other.head->next; node != other.tail; node = node->next) { + insert(node->value); + } + } + + void insert(uint64_t value) + { + if (nodes[value]->value == value) { + return; + } + Node *node = nodes[value]; + node->value = value; + Node *next = head->next; + node->next = next; + node->prev = head; + head->next = node; + next->prev = node; + } + + void remove(uint64_t value) + { + if (nodes[value]->value != value) { + return; + } + Node *node = nodes[value]; + node->prev->next = node->next; + node->next->prev = node->prev; + node->value = NODE_DEFAULT_VALUE; + } + + bool find(uint64_t value) + { + return nodes[value]->value == value; + } + + class Iterator { + public: + Iterator(Node *node) : current(node) {} + bool operator != (const Iterator &other) const + { + return current != other.current; + } + const uint64_t &operator*() const + { + return current->value; + } + Iterator &operator ++ () + { + current = current->next; + return *this; + } + + private: + Node *current; + }; + + Iterator begin() + { + return { head->next }; + } + + Iterator end() + { + return { tail }; + } + +private: + Node *head; + Node *tail; + std::vector nodes; +}; + +} +#endif // MXREC_LIMITED_SET_H diff --git a/src/AccCTR/src/embedding_cache/offset_mapper/address_mapper.h b/src/AccCTR/src/embedding_cache/offset_mapper/address_mapper.h new file mode 100644 index 0000000000000000000000000000000000000000..8b7e4e67f9e900a7dfde4c8c065a8e66d7384ce8 --- /dev/null +++ 
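
`LimitedSet` trades memory for determinism: it preallocates one node per possible value in `[0, maxRange)`, so `insert`, `remove` and `find` are O(1) with no hashing, and iteration visits members in most-recently-inserted-first order. A usage sketch:

```cpp
// O(1) membership over a dense value range [0, maxRange); each value owns a fixed node.
EmbCache::LimitedSet seen(/*maxRange=*/1024);
seen.insert(5);
seen.insert(9);
seen.remove(5);
for (uint64_t v : seen) { /* visits 9 only, newest first */ }
bool has9 = seen.find(9); // true
```
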
b/src/AccCTR/src/embedding_cache/offset_mapper/address_mapper.h @@ -0,0 +1,309 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef MXREC_FASTER_QUERY_H +#define MXREC_FASTER_QUERY_H + +#include +#include +#include +#include +#include +#include +#include + +#include "embedding_cache.h" +#include "offset_mapper/mapper_base.h" +#include "securec.h" + +namespace EmbCache { +using EmExpandMemUint = struct em_expand_memory_uint_ { + uint64_t address = 0; + uint64_t capacity = 0; + uint64_t leftCapacity = 0; + + em_expand_memory_uint_() = default; + + em_expand_memory_uint_(uint64_t a, uint64_t c) : address(a), capacity(c), leftCapacity(c) {} +}; + +template +class QWithLock { +public: + bool pop(T& ele) + { + std::lock_guard lk(mut); + if (dataQ.empty()) { + return false; + } + ele = dataQ.front(); + dataQ.pop(); + return true; + } + + void push(const T& ele) + { + std::lock_guard lk(mut); + dataQ.push(ele); + } + + uint64_t GetLength() + { + std::lock_guard lk(mut); + return dataQ.size(); + } + +private: + std::mutex mut; + std::queue dataQ; +}; + +class AutoRefillEmbeddingMemoryPool { +public: + std::vector expandedMemory; + uint32_t extEmbeddingSize; + std::vector initializerInfos; + + AutoRefillEmbeddingMemoryPool(uint64_t bufferSize, std::vector initInfos, uint32_t extEmbSize, + uint64_t hostVocabSize, uint32_t refillThreadNum = 1) + : extEmbeddingSize(extEmbSize), + initializerInfos(std::move(initInfos)), + maxBufferSize(bufferSize), + totalLeftVocabSize(hostVocabSize), + numThreads(refillThreadNum) + { + itemSize = extEmbeddingSize * sizeof(float); + maxExpandSize = maxBufferSize * itemSize; + for (uint32_t i = 0; i < numThreads; i++) { + producerThreads.emplace_back([this] { ProducerWorker(); }); + } + } + + ~AutoRefillEmbeddingMemoryPool() + { + stop = true; + std::lock_guard lock(producerMutex); + producerCv.notify_all(); + fullCv.notify_all(); + for (auto& t : producerThreads) { + t.join(); + } + } + + void Stop() + { + stop = true; + std::lock_guard lock(producerMutex); + producerCv.notify_all(); + fullCv.notify_all(); + } + + BeforePutFuncState GetNewValueToBeInserted(uint64_t& value, uint32_t maxRetry = 1000) + { + for (uint32_t i = 0; i < maxRetry; i++) { + if (BufferBin.pop(value)) { + producerCv.notify_one(); + return BeforePutFuncState::BEFORE_SUCCESS; + }; + producerCv.notify_one(); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + ock::ExternalLogger::PrintLog( + ock::LogLevel::ERROR, + "Failed to get new address for embedding, it is likely due to refill thread memory allocation failure " + "or max retry has been reached. 
Please check for memory alloc error or increase refill thread num!"); + return BeforePutFuncState::BEFORE_FAIL; + } + + void GetValueToBeRecycled(uint64_t value) + { + std::lock_guard lock(producerMutex); + recycleBin.push(value); + full = false; + fullCv.notify_one(); + } + +private: + uint64_t maxBufferSize; + uint64_t totalLeftVocabSize; + uint32_t numThreads; + std::atomic currBufferSize{0}; + volatile std::atomic stop = false; + volatile std::atomic full = false; + std::mutex producerMutex; + std::mutex getAddrMutex; + std::condition_variable producerCv; + std::condition_variable fullCv; + QWithLock BufferBin; + QWithLock recycleBin; + std::vector producerThreads; + EmExpandMemUint currentMemoryUint{}; + uint64_t dynamicExpandRatio = 2; + uint64_t maxExpandSize; + uint64_t itemSize; + + bool GetNewAddr(uint64_t& newAddr) + { + std::lock_guard lg(getAddrMutex); + if (HM_UNLIKELY(currentMemoryUint.leftCapacity <= 0)) { + /* need to expand memory */ + uint64_t maxSize = std::min(maxExpandSize, totalLeftVocabSize * itemSize); + uint64_t newSize = currentMemoryUint.capacity + ? std::min(currentMemoryUint.capacity * dynamicExpandRatio, maxSize) + : itemSize; + if (newSize == 0) { + if (recycleBin.GetLength() == 0) { + full = true; + } + return false; + } + auto newAddress = (uint64_t)malloc(newSize); + if (newAddress == 0) { + ock::ExternalLogger::PrintLog(ock::LogLevel::WARN, "Refill thread allocate memory failed!"); + return false; + } + expandedMemory.emplace_back(newAddress, newSize); + currentMemoryUint.address = newAddress; + currentMemoryUint.capacity = newSize; + currentMemoryUint.leftCapacity = newSize; + totalLeftVocabSize -= newSize / itemSize; + } + newAddr = currentMemoryUint.address + currentMemoryUint.capacity - currentMemoryUint.leftCapacity; + currentMemoryUint.leftCapacity -= itemSize; + return true; + } + + void Produce() + { + uint64_t newAddr; + if (!recycleBin.pop(newAddr)) { + if (!GetNewAddr(newAddr)) { + return; + } + } + GenerateData(newAddr); + BufferBin.push(newAddr); + } + + void GenerateData(const uint64_t& addr) + { + auto* embAddr = reinterpret_cast(addr); + for (const auto& initializerInfo : initializerInfos) { + initializerInfo.initializer->GenerateData(embAddr, INVALID_EMB_SIZE); + } + } + + void ProducerWorker() + { + std::unique_lock lock(producerMutex); + while (!stop) { + if (full) { + fullCv.wait(lock); + continue; + } + if (BufferBin.GetLength() < maxBufferSize) { + Produce(); + continue; + } + producerCv.wait(lock); + } + } +}; + +class AddressMapper : public MapperBase { +public: + AddressMapper() = default; + + ~AddressMapper() = default; + + bool Initialize(uint32_t reserve, uint32_t vocabSize, std::shared_ptr expendInfoPtr) + { + hostVocabSize = vocabSize; + emExpendMemInfoPtr = expendInfoPtr; + return MapperBase::Initialize(reserve); + } + + void UnInitialize() override + { + emExpendMemInfoPtr->Stop(); + FreeExpandedMemory(); + MapperBase::UnInitialize(); + } + + FkvState Remove(uint64_t key) + { + return MapperBase::Remove(key, [&](uint64_t value) { + emExpendMemInfoPtr->GetValueToBeRecycled(value); + return BeforeRemoveFuncState::BEFORE_SUCCESS; + }); + } + + FkvState FindAndPutIfNotFound(uint64_t key, uint64_t& value) + { + FkvState ret = MapperBase::FindAndPutIfNotFound(key, value, [&]() { + if (HM_UNLIKELY(current_size.load() >= hostVocabSize)) { + ock::ExternalLogger::PrintLog( + ock::LogLevel::ERROR, + "host does not have enough space, current: " + std::to_string(current_size.load()) + + ", host max size: " + 
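
`GetNewAddr` grows the pool geometrically: each new slab is `min(2 * previous capacity, maxExpandSize, remaining vocab bytes)`, starting from a single item. A small sketch of the resulting slab schedule, under the assumption of a fresh pool, `itemSize = 64` bytes and a large `maxExpandSize`:

```cpp
// Slab-growth sketch for the refill pool above: capacities double per expansion.
uint64_t capacity = 0;
const uint64_t itemSize = 64;
for (int expansions = 0; expansions < 4; expansions++) {
    uint64_t newSize = capacity ? capacity * 2 : itemSize;
    capacity = newSize; // 64, 128, 256, 512 bytes
}
```
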
std::to_string(hostVocabSize)); + return BeforePutFuncState::BEFORE_NO_SPACE; + } + return emExpendMemInfoPtr->GetNewValueToBeInserted(value); + }); + if (ret == FkvState::FKV_FAIL) { + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "FindAndPutIfNotFound failed!"); + return ret; + } + if (ret == FkvState::FKV_BEFORE_PUT_FUNC_FAIL) { + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "malloc failed"); + return ret; + } + return ret; + } + + // 如果多线程使用,严格保证传入的key线程间不会重复(unique key),否则可能出现未定义结果 + FkvState FindAndRemoveIfFound(uint64_t key, const uint64_t startAddr) + { + return MapperBase::Remove(key, [&](uint64_t value) { + uint64_t memSize = emExpendMemInfoPtr->extEmbeddingSize * sizeof(float); + auto rc = memcpy_s(reinterpret_cast(startAddr), memSize, reinterpret_cast(value), memSize); + if (rc != 0) { + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, + "memcpy_s failed... dstSize: " + std::to_string(memSize)); + return BeforeRemoveFuncState::BEFORE_FAIL; + } + emExpendMemInfoPtr->GetValueToBeRecycled(value); + return BeforeRemoveFuncState::BEFORE_SUCCESS; + }); + } + + uint32_t GetUsage() + { + return MapperBase::current_size; + } + +private: + void FreeExpandedMemory() + { + for (auto& memUint : emExpendMemInfoPtr->expandedMemory) { + free(reinterpret_cast(memUint.address)); + } + } + +private: + uint32_t hostVocabSize; + std::shared_ptr emExpendMemInfoPtr; +}; +} // namespace EmbCache +#endif // MXREC_FASTER_QUERY_H diff --git a/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h b/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h new file mode 100644 index 0000000000000000000000000000000000000000..42d62ca4f576c88da8c330081bdb3c274c2dfb55 --- /dev/null +++ b/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h @@ -0,0 +1,831 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ ==============================================================================*/ + +#ifndef MXREC_MAPPER_BASE_H +#define MXREC_MAPPER_BASE_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "securec.h" +#include "embedding_cache/common.h" +#include "external_logger.h" + +namespace EmbCache { +/* + * @brief Allocator template, for extend memory allocation for overflowed buckets + */ + +static constexpr size_t K_ALIGNMENT = 64; +static constexpr size_t K_KVNUMINBUCKET = 3; + +enum BucketIdx { + FIRST, + SECOND, + THIRD +}; + +class NetHeapAllocator { +public: + void *Allocate(uint32_t size) + { + return calloc(1, size); + } + + void Free(void *p) + { + if (HM_LIKELY(p != nullptr)) { + free(p); + p = nullptr; + } + } +}; + +/* + * @brief Spin lock entry in bucket + * used for alloc overflowed buckets + */ + +struct NetHashLockEntry { + uint64_t lock = 0; + + /* + * @brief Spin lock + */ + void Lock() + { + while (!__sync_bool_compare_and_swap(&lock, 0, 1)) { + } + } + + /* + * @brief Unlock + */ + void UnLock() + { + __atomic_store_n(&lock, 0, __ATOMIC_SEQ_CST); + } +} __attribute__((packed)); + +/* + * @brief Store the key/value into a linked array with 6 items, + * because 64bytes is one cache line + */ + +struct alignas(K_ALIGNMENT)NetHashBucket { + std::atomic keys[K_KVNUMINBUCKET]{}; + uint64_t values[K_KVNUMINBUCKET]{}; + NetHashBucket *next = nullptr; + NetHashLockEntry spinLock{}; + + FkvState Put(uint64_t key, uint64_t &value, const std::function &beforePutFunc) + { + /* don't put them into loop, flat code is faster than loop */ + uint64_t oldKey = 0; + if (keys[BucketIdx::FIRST].load(std::memory_order_relaxed) == 0 && + keys[BucketIdx::FIRST].compare_exchange_strong(oldKey, key)) { + BeforePutFuncState ret = beforePutFunc(); + if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_FAIL)) { + keys[BucketIdx::FIRST] = 0; + return FkvState::FKV_BEFORE_PUT_FUNC_FAIL; + } + if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_NO_SPACE)) { + keys[BucketIdx::FIRST] = 0; + return FkvState::FKV_NO_SPACE; + } + values[BucketIdx::FIRST] = value; + return FkvState::FKV_NOT_EXIST; + } + + if (HM_UNLIKELY(oldKey == key)) { + return FkvState::FKV_KEY_CONFLICT; + } + + oldKey = 0; + if (keys[BucketIdx::SECOND].load(std::memory_order_relaxed) == 0 && + keys[BucketIdx::SECOND].compare_exchange_strong(oldKey, key)) { + BeforePutFuncState ret = beforePutFunc(); + if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_FAIL)) { + keys[BucketIdx::SECOND] = 0; + return FkvState::FKV_BEFORE_PUT_FUNC_FAIL; + } + if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_NO_SPACE)) { + keys[BucketIdx::SECOND] = 0; + return FkvState::FKV_NO_SPACE; + } + values[BucketIdx::SECOND] = value; + return FkvState::FKV_NOT_EXIST; + } + + if (HM_UNLIKELY(oldKey == key)) { + return FkvState::FKV_KEY_CONFLICT; + } + + oldKey = 0; + if (keys[BucketIdx::THIRD].load(std::memory_order_relaxed) == 0 && + keys[BucketIdx::THIRD].compare_exchange_strong(oldKey, key)) { + BeforePutFuncState ret = beforePutFunc(); + if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_FAIL)) { + keys[BucketIdx::THIRD] = 0; + return FkvState::FKV_BEFORE_PUT_FUNC_FAIL; + } + if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_NO_SPACE)) { + keys[BucketIdx::THIRD] = 0; + return FkvState::FKV_NO_SPACE; + } + values[BucketIdx::THIRD] = value; + return FkvState::FKV_NOT_EXIST; + } + + if (HM_UNLIKELY(oldKey == key)) { + return FkvState::FKV_KEY_CONFLICT; + } + + return FkvState::FKV_FAIL; + } + + /* + * @brief Remove 
+
+    /*
+     * @brief Find the value mapped to a key in this bucket
+     */
+    bool Find(const uint64_t key, uint64_t &value)
+    {
+        /*
+         * loop expanded by hand instead of a for/while loop, for performance
+         */
+        if (key == keys[BucketIdx::FIRST].load(std::memory_order_relaxed)) {
+            value = values[BucketIdx::FIRST];
+            return true;
+        }
+
+        if (key == keys[BucketIdx::SECOND].load(std::memory_order_relaxed)) {
+            value = values[BucketIdx::SECOND];
+            return true;
+        }
+
+        if (key == keys[BucketIdx::THIRD].load(std::memory_order_relaxed)) {
+            value = values[BucketIdx::THIRD];
+            return true;
+        }
+
+        return false;
+    }
+
+    FkvState Remove(uint64_t key)
+    {
+        /* don't put them into a loop; flat code is faster than a loop */
+        uint64_t oldValue = key;
+        if (keys[BucketIdx::FIRST].load(std::memory_order_relaxed) == key &&
+            keys[BucketIdx::FIRST].compare_exchange_strong(oldValue, 0)) {
+            values[BucketIdx::FIRST] = 0;
+            return FkvState::FKV_EXIST;
+        }
+        if (HM_UNLIKELY(oldValue == 0)) {
+            return FkvState::FKV_EXIST;
+        }
+        oldValue = key;
+
+        if (keys[BucketIdx::SECOND].load(std::memory_order_relaxed) == key &&
+            keys[BucketIdx::SECOND].compare_exchange_strong(oldValue, 0)) {
+            values[BucketIdx::SECOND] = 0;
+            return FkvState::FKV_EXIST;
+        }
+        if (HM_UNLIKELY(oldValue == 0)) {
+            return FkvState::FKV_EXIST;
+        }
+        oldValue = key;
+
+        if (keys[BucketIdx::THIRD].load(std::memory_order_relaxed) == key &&
+            keys[BucketIdx::THIRD].compare_exchange_strong(oldValue, 0)) {
+            values[BucketIdx::THIRD] = 0;
+            return FkvState::FKV_EXIST;
+        }
+        if (HM_UNLIKELY(oldValue == 0)) {
+            return FkvState::FKV_EXIST;
+        }
+
+        return FkvState::FKV_NOT_EXIST;
+    }
+
+    FkvState Remove(uint64_t key, const std::function<BeforeRemoveFuncState(uint64_t)> &beforeRemoveFunc)
+    {
+        /* don't put them into a loop; flat code is faster than a loop */
+        uint64_t oldValue = key;
+        if (keys[BucketIdx::FIRST].load(std::memory_order_relaxed) == key &&
+            keys[BucketIdx::FIRST].compare_exchange_strong(oldValue, 0)) {
+            if (HM_UNLIKELY(beforeRemoveFunc(values[BucketIdx::FIRST]) == BeforeRemoveFuncState::BEFORE_FAIL)) {
+                return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL;
+            }
+
+            values[BucketIdx::FIRST] = 0;
+            return FkvState::FKV_EXIST;
+        }
+        if (HM_UNLIKELY(oldValue == 0)) {
+            return FkvState::FKV_EXIST;
+        }
+        oldValue = key;
+
+        if (keys[BucketIdx::SECOND].load(std::memory_order_relaxed) == key &&
+            keys[BucketIdx::SECOND].compare_exchange_strong(oldValue, 0)) {
+            if (HM_UNLIKELY(beforeRemoveFunc(values[BucketIdx::SECOND]) == BeforeRemoveFuncState::BEFORE_FAIL)) {
+                return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL;
+            }
+
+            values[BucketIdx::SECOND] = 0;
+            return FkvState::FKV_EXIST;
+        }
+        if (HM_UNLIKELY(oldValue == 0)) {
+            return FkvState::FKV_EXIST;
+        }
+        oldValue = key;
+
+        if (keys[BucketIdx::THIRD].load(std::memory_order_relaxed) == key &&
+            keys[BucketIdx::THIRD].compare_exchange_strong(oldValue, 0)) {
+            if (HM_UNLIKELY(beforeRemoveFunc(values[BucketIdx::THIRD]) == BeforeRemoveFuncState::BEFORE_FAIL)) {
+                return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL;
+            }
+
+            values[BucketIdx::THIRD] = 0;
+            return FkvState::FKV_EXIST;
+        }
+        if (HM_UNLIKELY(oldValue == 0)) {
+            return FkvState::FKV_EXIST;
+        }
+
+        return FkvState::FKV_NOT_EXIST;
+    }
+};
+
+
+class MapperBase {
+public:
+    std::atomic<uint64_t> current_size{ 0 };
+
+    MapperBase() = default;
+
+    virtual ~MapperBase() = default;
+
+    bool Initialize(uint32_t reserve)
+    {
+        /* already initialized */
+        if (mOverflowEntryAlloc != nullptr) {
+            return true;
+        }
+
+        /* get a proper bucket count */
+        uint32_t bucketCount = std::max(reserve, uint32_t(128));
+        if
(bucketCount > gPrimes[gPrimesCount - 1]) { + bucketCount = gPrimes[gPrimesCount - 1]; + } else { + uint32_t i = 0; + while (i < gPrimesCount && gPrimes[i] < bucketCount) { + i++; + } + bucketCount = gPrimes[i]; + } + + /* allocate buckets for sub-maps */ + for (auto &mSubMap : mSubMaps) { + NetHashBucket* tmp; + if (!NewAndSetBucket(bucketCount, 0, tmp)) { return false;} + mSubMap = tmp; + } + + /* create overflow entry allocator */ + mOverflowEntryAlloc = new (std::nothrow) NetHeapAllocator(); + if (HM_UNLIKELY(mOverflowEntryAlloc == nullptr)) { + FreeSubMaps(); + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, + "Failed to new overflow entry allocator, probably out of memory"); + return false; + } + + /* set bucket count */ + mBucketCount = bucketCount; + ock::ExternalLogger::PrintLog(ock::LogLevel::INFO, + "fastKV inited, mBucketCount: " + std::to_string(mBucketCount)); + return true; + } + + virtual void UnInitialize() + { + if (mOverflowEntryAlloc == nullptr) { + return; + } + + /* free overflowed entries firstly */ + FreeOverFlowedEntries(); + + /* free sub map secondly */ + FreeSubMaps(); + + /* free overflow entry at last */ + delete mOverflowEntryAlloc; + mOverflowEntryAlloc = nullptr; + mBucketCount = 0; + } + + FkvState FindAndPutIfNotFound(uint64_t key, uint64_t &value, + const std::function &beforePutFunc) + { + if (HM_UNLIKELY(key == 0)) { + if (zeroInside) { + value = zeroValue; + return FkvState::FKV_EXIST; + } + if (__sync_bool_compare_and_swap(&zeroInside, false, true)) { + BeforePutFuncState ret = beforePutFunc(); + if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_FAIL)) { + return FkvState::FKV_BEFORE_PUT_FUNC_FAIL; + } + if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_NO_SPACE)) { + return FkvState::FKV_NO_SPACE; + } + zeroValue = value; + current_size++; + return FkvState::FKV_NOT_EXIST; + } + return FkvState::FKV_KEY_CONFLICT; + } + + /* get bucket */ + auto buck = &(mSubMaps[key % gSubMapCount][key % mBucketCount]); + + /* loop all buckets linked */ + while (buck != nullptr) { + buck->spinLock.Lock(); + if (buck->Find(key, value)) { + buck->spinLock.UnLock(); + return FkvState::FKV_EXIST; + } + buck->spinLock.UnLock(); + + if (buck->next != nullptr) { + buck = buck->next; + } else { + break; + } + } + + // did not find, now do put. 
continue from the last bucket in find + return PutKeyValue(key, value, buck, beforePutFunc); + } + + FkvState Remove(uint64_t key) + { + if (HM_UNLIKELY(key == 0)) { + if (zeroInside) { + if (__sync_bool_compare_and_swap(&zeroInside, true, false)) { + zeroValue = 0; + current_size--; + } + return FkvState::FKV_EXIST; + } + return FkvState::FKV_NOT_EXIST; + } + + /* get bucket */ + auto buck = &(mSubMaps[key % gSubMapCount][key % mBucketCount]); + + /* loop all buckets linked */ + uint64_t value; + while (buck != nullptr) { + if (buck->Find(key, value)) { + buck->Remove(key); + current_size--; + return FkvState::FKV_EXIST; + } + + buck = buck->next; + } + + return FkvState::FKV_NOT_EXIST; + } + + FkvState Remove(uint64_t key, const std::function &beforeRemoveFunc) + { + if (HM_UNLIKELY(key == 0)) { + if (!zeroInside) { + return FkvState::FKV_NOT_EXIST; + } + if (__sync_bool_compare_and_swap(&zeroInside, true, false)) { + auto ret = beforeRemoveFunc(zeroValue); + if (HM_UNLIKELY(ret == BeforeRemoveFuncState::BEFORE_FAIL)) { + return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; + } + zeroValue = 0; + current_size--; + } + return FkvState::FKV_EXIST; + } + + /* get bucket */ + auto buck = &(mSubMaps[key % gSubMapCount][key % mBucketCount]); + + /* loop all buckets linked */ + uint64_t value; + while (buck != nullptr) { + if (buck->Find(key, value)) { + auto ret = buck->Remove(key, beforeRemoveFunc); + if (HM_UNLIKELY(ret == FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL)) { + return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; + } + + current_size--; + return FkvState::FKV_EXIST; + } + + buck = buck->next; + } + + return FkvState::FKV_NOT_EXIST; + } + + FkvState Put(uint64_t key, uint64_t value) + { + if (HM_UNLIKELY(key == 0)) { + if (__sync_bool_compare_and_swap(&zeroInside, false, true)) { + zeroValue = value; + current_size++; + return FkvState::FKV_NOT_EXIST; + } + return FkvState::FKV_KEY_CONFLICT; + } + + /* get bucket */ + auto buck = &(mSubMaps[key % gSubMapCount][key % mBucketCount]); + /* loop all buckets linked */ + while (buck != nullptr) { + if (buck->next != nullptr) { + buck = buck->next; + } else { + break; + } + } + + // did not find, now do put. continue from the last bucket in find + /* try 8192 times */ + for (uint16_t i = 0; i < 8192; i++) { + /* loop all buckets linked */ + while (buck != nullptr) { + /* if there is an entry to put, just break */ + FkvState putRet = buck->Put(key, value, []() -> BeforePutFuncState { return {}; }); + if (putRet == FkvState::FKV_NOT_EXIST) { + current_size++; + return FkvState::FKV_NOT_EXIST; + } + + if (HM_UNLIKELY(putRet == FkvState::FKV_KEY_CONFLICT)) { + return FkvState::FKV_KEY_CONFLICT; + } + /* + * if no next bucket exist, just for break, + * else move to next bucket linked + */ + if (buck->next == nullptr) { + break; + } else { + buck = buck->next; + } + } + + /* + * if not put successfully in existing buckets, allocate a new one + * + * NOTES: just allocate memory, don't access new bucket in the spin lock scope, + * if access new bucket, which could trigger physical memory allocation which + * could trigger page fault, that is quite slow. 
In this case, spin lock + * could occupy too much CPU + */ + auto &lock = buck->spinLock; + lock.Lock(); + /* if other thread allocated new buck already, unlock and continue */ + if (buck->next != nullptr) { + buck = buck->next; + lock.UnLock(); + continue; + } + + /* firstly entered thread allocate new bucket */ + auto newBuck = static_cast(mOverflowEntryAlloc->Allocate(sizeof(NetHashBucket))); + if (HM_UNLIKELY(newBuck == nullptr)) { + lock.UnLock(); + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "Failed to allocate new bucket"); + return FkvState::FKV_FAIL; + } + /* link to current buck, set buck to new buck */ + buck->next = newBuck; + buck = newBuck; + + /* unlock */ + lock.UnLock(); + } + return FkvState::FKV_FAIL; + } + + bool Find(const uint64_t key, uint64_t &value) + { + if (HM_UNLIKELY(key == 0)) { + if (zeroInside) { + value = zeroValue; + return true; + } + return false; + } + /* get bucket */ + auto buck = &(mSubMaps[key % gSubMapCount][key % mBucketCount]); + + /* loop all buckets linked */ + while (buck != nullptr) { + if (buck->Find(key, value)) { + return true; + } + + buck = buck->next; + } + + return false; + } + + /* When used in muti thread, this function can only be used when keys are uniqued */ + FkvState FindAndDeleteIfFound(const uint64_t key, uint64_t &value, + const std::function &beforeRemoveFunc) + { + if (HM_UNLIKELY(key == 0)) { + if (!zeroInside) { + return FkvState::FKV_NOT_EXIST; + } + value = zeroValue; + if (__sync_bool_compare_and_swap(&zeroInside, true, false)) { + auto ret = beforeRemoveFunc(zeroValue); + if (HM_UNLIKELY(ret == BeforeRemoveFuncState::BEFORE_FAIL)) { + return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; + } + zeroValue = 0; + current_size--; + } + + return FkvState::FKV_EXIST; + } + /* get bucket */ + auto buck = &(mSubMaps[key % gSubMapCount][key % mBucketCount]); + + while (buck != nullptr) { + if (buck->Find(key, value)) { + auto ret = buck->Remove(key, beforeRemoveFunc); + if (HM_UNLIKELY(ret == FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL)) { + return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; + } + current_size--; + return FkvState::FKV_EXIST; + } + + buck = buck->next; + } + + return FkvState::FKV_NOT_EXIST; + } + + std::vector> ExportVec() + { + std::vector> kvVec; + if (zeroInside) { + kvVec.emplace_back(0, zeroValue); + } + for (auto &mSubMap : mSubMaps) { + for (uint32_t j = 0; j < mBucketCount; j++) { + auto buck = &mSubMap[j]; + ExtractKeyValInBuck(buck, kvVec); + } + } + return kvVec; + } + +protected: + static constexpr uint16_t gSubMapCount = 5; /* count of sub map */ + static constexpr uint32_t gPrimesCount = 256; + + /* make sure the size of this class is 64 bytes, fit into one cache line */ + NetHeapAllocator *mOverflowEntryAlloc = nullptr; /* allocate overflowed entry in one bucket */ + NetHashBucket *mSubMaps[gSubMapCount]{}; /* sub map */ + uint32_t mBucketCount = 0; /* bucket count of each sub map */ + uint32_t mBaseSize = 4096; /* base size */ + bool zeroInside = false; + uint64_t zeroValue = 0; + + const uint32_t gPrimes[gPrimesCount] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, + 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, + 97, 103, 109, 113, 127, 137, 139, 149, 157, 167, + 179, 193, 199, 211, 227, 241, 257, 277, 293, 313, + 337, 359, 383, 409, 439, 467, 503, 541, 577, 619, + 661, 709, 761, 823, 887, 953, 1031, 1109, 1193, 1289, + 1381, 1493, 1613, 1741, 1879, 2029, 2179, 2357, 2549, + 2753, 2971, 3209, 3469, 3739, 4027, 4349, 4703, 5087, + 5503, 5953, 6427, 6949, 7517, 8123, 8783, 9497, 10273, + 11113, 12011, 
12983, 14033, 15173, 16411, 17749, 19183, + 20753, 22447, 24281, 26267, 28411, 30727, 33223, 35933, + 38873, 42043, 45481, 49201, 53201, 57557, 62233, 67307, + 72817, 78779, 85229, 92203, 99733, 107897, 116731, 126271, + 136607, 147793, 159871, 172933, 187091, 202409, 218971, 236897, + 256279, 277261, 299951, 324503, 351061, 379787, 410857, 444487, + 480881, 520241, 562841, 608903, 658753, 712697, 771049, 834181, + 902483, 976369, 1056323, 1142821, 1236397, 1337629, 1447153, + 1565659, 1693859, 1832561, 1982627, 2144977, 2320627, 2510653, + 2716249, 2938679, 3179303, 3439651, 3721303, 4026031, 4355707, + 4712381, 5098259, 5515729, 5967347, 6456007, 6984629, 7556579, + 8175383, 8844859, 9569143, 10352717, 11200489, 12117689, + 13109983, 14183539, 15345007, 16601593, 17961079, 19431899, + 21023161, 22744717, 24607243, 26622317, 28802401, 31160981, + 33712729, 36473443, 39460231, 42691603, 46187573, 49969847, + 54061849, 58488943, 63278561, 68460391, 74066549, 80131819, + 86693767, 93793069, 101473717, 109783337, 118773397, 128499677, + 139022417, 150406843, 162723577, 176048909, 190465427, + 206062531, 222936881, 241193053, 260944219, 282312799, + 305431229, 330442829, 357502601, 386778277, 418451333, + 452718089, 489790921, 529899637, 573292817, 620239453, + 671030513, 725980837, 785430967, 849749479, 919334987, + 994618837, 1076067617, 1164186217, 1259520799, 1362662261, + 1474249943, 1594975441, 1725587117, 1866894511, 2019773507, + 2185171673, 2364114217, 2557710269, 2767159799, 2993761039, + 3238918481, 3504151727, 3791104843, 4101556399, 4294967291}; + +private: + void FreeSubMaps() + { + /* free all sub maps */ + for (auto &mSubMap : mSubMaps) { + if (mSubMap != nullptr) { + delete[] mSubMap; + mSubMap = nullptr; + } + } + } + + /* + * Description: allocate buckets and init it + * Parameter: bucketCount - the bucket counts + * Parameter: c - the value to be copied + * Parameter: bucketPtr - pointing at the bucket array which is allocated + * NOTES: SECUREC_MEM_MAX_LEN of memset_s function is 2GB + */ + bool NewAndSetBucket(const uint32_t& bucketCount, const int& c, NetHashBucket* &bucketPtr) + { + bucketPtr = new (std::nothrow) NetHashBucket[bucketCount]; + if (HM_UNLIKELY(bucketPtr == nullptr)) { + FreeSubMaps(); + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, + "Failed to new hash bucket, probably out of memory"); + return false; + } + + /* make physical page and set to zero */ + size_t bucketsBytes = sizeof(NetHashBucket) * bucketCount; + char* destBytePtr = reinterpret_cast(bucketPtr); + for (size_t i = 0; i < bucketsBytes; i += MEMSET_S_MAX_SIZE) { + size_t bytesOnceSet = (i + MEMSET_S_MAX_SIZE <= bucketsBytes) ? MEMSET_S_MAX_SIZE : (bucketsBytes - i); + auto ret = memset_s(destBytePtr + i, bytesOnceSet, c, bytesOnceSet); + if (ret != 0) { + delete[] bucketPtr; + bucketPtr = nullptr; + FreeSubMaps(); + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, + "memset_s failed... 
size: " + std::to_string(bucketsBytes) + ", error code:" + std::to_string(ret)); + return false; + } + } + return true; + } + + void FreeOverFlowedEntries() + { + for (auto &mSubMap : mSubMaps) { + if (mSubMap == nullptr) { + continue; + } + + /* free overflow entries in one sub map */ + for (uint32_t buckIndex = 0; buckIndex < mBucketCount; ++buckIndex) { + auto curBuck = mSubMap[buckIndex].next; + NetHashBucket *nextOverflowEntryBuck = nullptr; + + /* exit loop when curBuck is null */ + while (curBuck != nullptr) { + /* assign next overflow buck to tmp variable */ + nextOverflowEntryBuck = curBuck->next; + + /* free this overflow bucket */ + mOverflowEntryAlloc->Free(curBuck); + + /* assign next to current */ + curBuck = nextOverflowEntryBuck; + } + } + } + } + + FkvState PutKeyValue(uint64_t key, uint64_t& value, EmbCache::NetHashBucket *buck, + const std::function& beforePutFunc) + { + /* try 8192 times */ + for (uint16_t i = 0; i < 8192; i++) { + /* loop all buckets linked */ + while (buck != nullptr) { + /* if there is an entry to put, just break */ + buck->spinLock.Lock(); + FkvState putRet = buck->Put(key, value, beforePutFunc); + buck->spinLock.UnLock(); + if (putRet == FkvState::FKV_NOT_EXIST) { + current_size++; + return FkvState::FKV_NOT_EXIST; + } + + if (HM_UNLIKELY(putRet == FkvState::FKV_KEY_CONFLICT)) { + return FkvState::FKV_KEY_CONFLICT; + } + + if (HM_UNLIKELY(putRet == FkvState::FKV_BEFORE_PUT_FUNC_FAIL)) { + return FkvState::FKV_BEFORE_PUT_FUNC_FAIL; + } + + if (HM_UNLIKELY(putRet == FkvState::FKV_NO_SPACE)) { + return FkvState::FKV_NO_SPACE; + } + + /* + * if no next bucket exist, just for break, + * else move to next bucket linked + */ + if (buck->next == nullptr) { + break; + } else { + buck = buck->next; + } + } + + /* + * if not put successfully in existing buckets, allocate a new one + * + * NOTES: just allocate memory, don't access new bucket in the spin lock scope, + * if access new bucket, which could trigger physical memory allocation which + * could trigger page fault, that is quite slow. In this case, spin lock + * could occupy too much CPU + */ + auto &lock = buck->spinLock; + lock.Lock(); + /* if other thread allocated new buck already, unlock and continue */ + if (buck->next != nullptr) { + buck = buck->next; + lock.UnLock(); + continue; + } + + /* firstly entered thread allocate new bucket */ + auto newBuck = static_cast(mOverflowEntryAlloc->Allocate(sizeof(NetHashBucket))); + if (HM_UNLIKELY(newBuck == nullptr)) { + lock.UnLock(); + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "Failed to allocate new bucket"); + return FkvState::FKV_FAIL; + } + /* link to current buck, set buck to new buck */ + buck->next = newBuck; + buck = newBuck; + + /* unlock */ + lock.UnLock(); + } + return FkvState::FKV_FAIL; + } + + void ExtractKeyValInBuck(EmbCache::NetHashBucket *buck, std::vector>& kvVec) + { + while (buck) { + for (size_t k = 0; k < K_KVNUMINBUCKET; k++) { + if (buck->keys[k] == 0) { + continue; + } + kvVec.emplace_back(buck->keys[k].load(), buck->values[k]); + } + buck = buck->next; + } + } +}; +} +#endif // MXREC_MAPPER_BASE_H diff --git a/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h b/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h new file mode 100644 index 0000000000000000000000000000000000000000..1ad470c5bf4fb9bf7dea441f7874466849b13b39 --- /dev/null +++ b/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h @@ -0,0 +1,280 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+ ==============================================================================*/
+
+#ifndef MXREC_OFFSET_MAPPER_H
+#define MXREC_OFFSET_MAPPER_H
+
+#include <algorithm>
+#include <atomic>
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <new>
+#include <string>
+#include <utility>
+#include <vector>
+#include "mapper_base.h"
+
+namespace EmbCache {
+class OffsetMapper : public MapperBase {
+public:
+    OffsetMapper() = default;
+
+    ~OffsetMapper() = default;
+
+    OffsetMapper(const OffsetMapper& other): maxCacheSize(other.maxCacheSize), useLength(other.useLength),
+                                             validPos(new LimitedSet(*other.validPos)),
+                                             evictPos(new LimitedSet(*other.evictPos)),
+                                             pos2Key(other.pos2Key), lastBatchPos(other.lastBatchPos),
+                                             evictSize(other.evictSize)
+    {
+    }
+
+    OffsetMapper& operator=(const OffsetMapper& other)
+    {
+        if (this != &other) {
+            delete validPos;
+            validPos = nullptr;
+            delete evictPos;
+            evictPos = nullptr;
+
+            if (other.validPos != nullptr) {
+                validPos = new LimitedSet(*other.validPos);
+            }
+            if (other.evictPos != nullptr) {
+                evictPos = new LimitedSet(*other.evictPos);
+            }
+
+            maxCacheSize = other.maxCacheSize;
+            useLength = other.useLength;
+            pos2Key = other.pos2Key;
+            lastBatchPos = other.lastBatchPos;
+            evictSize = other.evictSize;
+        }
+        return *this;
+    }
+
+    bool Initialize(uint32_t reserve, uint32_t maxSize = 0)
+    {
+        maxCacheSize = maxSize;
+        useLength = 0;
+        pos2Key.resize(maxSize);
+        std::fill(pos2Key.begin(), pos2Key.end(), INVALID_KEY);
+        try {
+            validPos = new LimitedSet(maxSize);
+            evictPos = new LimitedSet(maxSize);
+        } catch (const std::bad_alloc &e) {
+            return false;
+        }
+        return MapperBase::Initialize(reserve);
+    }
+
+    void UnInitialize() override
+    {
+        delete validPos;
+        delete evictPos;
+        validPos = nullptr;
+        evictPos = nullptr;
+        MapperBase::UnInitialize();
+    }
+
+    FkvState Remove(uint64_t key)
+    {
+        return MapperBase::Remove(key, [&](uint64_t value) {
+            validPos->remove(value);
+            auto pos = std::find(lastBatchPos.begin(), lastBatchPos.end(), value);
+            if (pos != lastBatchPos.end()) {
+                lastBatchPos.erase(pos);
+            }
+            evictPos->insert(value);
+            evictSize++;
+            return BeforeRemoveFuncState::BEFORE_SUCCESS;
+        });
+    }
+
+    std::vector<std::pair<uint64_t, uint64_t>> ExportSortedKVPairs()
+    {
+        auto koVec = ExportVec();
+        std::sort(koVec.begin(), koVec.end(), [](const auto &u, const auto &v) { return u.second < v.second; });
+        return koVec;
+    }
+
+    uint64_t GetFreeLength()
+    {
+        return maxCacheSize - useLength + evictSize;
+    }
+
+    int GetSwapPairsAndKey2Offset(std::vector<uint64_t>& keys, KeyOffsetPair& swapInKoPair,
+                                  KeyOffsetPair& swapOutKoPair)
+    {
+        std::vector<uint64_t> swapInKeysID = FilterKeys(keys, swapInKoPair);
+
+        uint64_t swapInCnt = 0;
+        auto ret = FindInUsedPos(keys, swapInCnt, swapInKeysID, swapInKoPair, swapOutKoPair);
+        if (ret != ock::ctr::H_OK) {
+            return ret;
+        }
+
+        // Allocate positions from the offset mapper for the remaining keys
+        ret = FindInOffsetMapper(keys, swapInKoPair, swapInCnt, swapInKeysID);
+        if (ret != ock::ctr::H_OK) {
+            return ret;
+        }
+
+        // Positions used by the previous batch may be swapped out again; add them to validPos
+        for (uint64_t pos : lastBatchPos) {
+            if (HM_UNLIKELY(pos == static_cast<uint64_t>(INVALID_KEY))) {
+                continue;
+            }
+            validPos->insert(pos);
+        }
+
+        // keys have been rewritten to offsets at this point; positions used by this batch
+        // must not be swapped out during the next batch, so remove them from validPos
+        for (uint64_t pos : keys) {
+            if (HM_UNLIKELY(pos == static_cast<uint64_t>(INVALID_KEY))) {
+                continue;
+            }
+            validPos->remove(pos);
+            evictPos->remove(pos);
+        }
+
+        lastBatchPos = keys;
+        return ock::ctr::H_OK;
+    }
+
+    uint32_t GetUsage()
+    {
+        return useLength - evictSize;
+    }
+
+    int FindInUsedPos(std::vector<uint64_t>& keys, uint64_t& swapInCnt, std::vector<uint64_t>& swapInKeysID,
+                      KeyOffsetPair& swapInKoPair, KeyOffsetPair& swapOutKoPair)
+    {
+        std::vector<uint64_t> &swapInKeys = swapInKoPair.first;
+        std::vector<uint64_t> &swapInPos = swapInKoPair.second;
+        std::vector<uint64_t> &swapOutKeys = swapOutKoPair.first;
+        std::vector<uint64_t> &swapOutPos = swapOutKoPair.second;
+
+        // swap-out count = swap-in count - remaining free space
+        uint64_t swapOutNum = swapInKeys.size() <= GetFreeLength() ? 0 : swapInKeys.size() - GetFreeLength();
+        swapOutKeys.resize(swapOutNum);
+        swapOutPos.resize(swapOutNum);
+
+        // Reuse positions from evictPos first
+        for (uint64_t pos : *evictPos) {
+            if (swapInCnt == swapInKeys.size()) {
+                break;
+            }
+            // record swapInPos
+            swapInPos[swapInCnt] = pos;
+            // key -> offset
+            keys[swapInKeysID[swapInCnt]] = pos;
+            // insert the new key-position pair
+            Put(swapInKeys[swapInCnt], pos);
+            // update pos2Key
+            pos2Key[pos] = swapInKeys[swapInCnt];
+            swapInCnt++;
+            evictSize--;
+        }
+
+        uint64_t swapOutCnt = 0;
+        // If space is still insufficient, take swappable positions from validPos for the
+        // first swapOutNum keys
+        for (uint64_t pos : *validPos) {
+            if (swapOutCnt == swapOutNum) {
+                break;
+            }
+            // record swapInPos
+            swapInPos[swapInCnt] = pos;
+            // key -> offset
+            keys[swapInKeysID[swapInCnt]] = pos;
+            // remove the old key-position pair and insert the new one
+            uint64_t key = pos2Key[pos];
+            MapperBase::Remove(key);
+            Put(swapInKeys[swapInCnt], pos);
+            // record swapOutKoPair
+            swapOutKeys[swapOutCnt] = key;
+            swapOutPos[swapOutCnt] = pos;
+            // update pos2Key
+            pos2Key[pos] = swapInKeys[swapInCnt];
+            swapInCnt++;
+            swapOutCnt++;
+        }
+
+        if (swapOutCnt < swapOutNum) {
+            ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "max cache size is too small");
+            return ock::ctr::H_MAX_CACHESIZE_TOO_SMALL;
+        }
+
+        return ock::ctr::H_OK;
+    }
+
+    int FindInOffsetMapper(std::vector<uint64_t>& keys, KeyOffsetPair& swapInKoPair, uint64_t swapInCnt,
+                           std::vector<uint64_t>& swapInKeysID)
+    {
+        std::vector<uint64_t> &swapInKeys = swapInKoPair.first;
+        std::vector<uint64_t> &swapInPos = swapInKoPair.second;
+
+        for (uint64_t i = swapInCnt; i < swapInKeys.size(); i++) {
+            swapInPos[i] = useLength++;
+            if (HM_UNLIKELY(swapInPos[i] >= maxCacheSize)) {
+                ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "max cache size is too small");
+                return ock::ctr::H_MAX_CACHESIZE_TOO_SMALL;
+            }
+            // insert the new key-position pair
+            Put(swapInKeys[i], swapInPos[i]);
+            // update pos2Key
+            pos2Key[swapInPos[i]] = swapInKeys[i];
+            // key -> offset
+            keys[swapInKeysID[i]] = swapInPos[i];
+        }
+        return ock::ctr::H_OK;
+    }
+
+    std::vector<uint64_t> FilterKeys(std::vector<uint64_t>& keys, KeyOffsetPair &swapInKoPair)
+    {
+        std::vector<uint64_t> &swapInKeys = swapInKoPair.first;
+        std::vector<uint64_t> &swapInPos = swapInKoPair.second;
+
+        std::vector<uint64_t> swapInKeysID;
+        for (uint64_t i = 0; i < keys.size(); i++) {
+            // skip invalid keys
+            if (HM_UNLIKELY(keys[i] == static_cast<uint64_t>(INVALID_KEY))) {
+                continue;
+            }
+            // Keys already in HBM are replaced in place by their position and removed from validPos.
+            // Keys not in HBM are appended to swapInKeys; their index in keys is recorded for the
+            // later key -> offset rewrite
+            if (Find(keys[i], keys[i])) {
+                validPos->remove(keys[i]);
+            } else {
+                swapInKeys.push_back(keys[i]);
+                swapInKeysID.push_back(i);
+            }
+        }
+        swapInPos.resize(swapInKeys.size());
+        return swapInKeysID;
+    }
+
+private:
+    uint32_t maxCacheSize{};            // number of embeddings the HBM cache can hold
+    uint32_t useLength{};               // number of embeddings currently stored in HBM
+    LimitedSet *validPos{};             // positions in HBM that may be swapped out
+    LimitedSet *evictPos{};             // positions freed by eviction
+    std::vector<uint64_t> pos2Key;      // key stored at each HBM position
+    std::vector<uint64_t> lastBatchPos; // positions occupied in HBM by the previous batch's keys
+    uint64_t evictSize = 0;             // number of entries in evictPos
+};
+}
+#endif // MXREC_OFFSET_MAPPER_H
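OffsetMapper drives the HBM swap logic above: GetSwapPairsAndKey2Offset rewrites a batch's unique keys into device offsets in place and reports which key/offset pairs must be copied in or out. A minimal usage sketch follows, assuming the header's dependencies (LimitedSet, the ock::ctr error codes) are on the include path; the batch values are illustrative only.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>
#include "embedding_cache/offset_mapper/offset_mapper.h"

int main()
{
    EmbCache::OffsetMapper mapper;
    // Reserve 1024 hash buckets; the device cache holds at most 8 embeddings.
    if (!mapper.Initialize(1024, 8)) {
        return 1;
    }

    std::vector<uint64_t> keys = {42, 7, 99};  // unique keys of one batch
    EmbCache::KeyOffsetPair swapIn;
    EmbCache::KeyOffsetPair swapOut;
    // On success, every entry of `keys` holds a device offset instead of a key.
    if (mapper.GetSwapPairsAndKey2Offset(keys, swapIn, swapOut) != ock::ctr::H_OK) {
        return 1;
    }
    for (size_t i = 0; i < swapIn.first.size(); ++i) {
        // Each swapped-in key's embedding must be copied to its device position.
        std::cout << "swap in key " << swapIn.first[i] << " -> pos " << swapIn.second[i] << "\n";
    }
    mapper.UnInitialize();
    return 0;
}
```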
diff --git a/src/AccCTR/src/factory_impl.cpp b/src/AccCTR/src/factory_impl.cpp
index f0f5cdac22eec4d4448e7e8379dcb7b3e69f6cab..654e1d7653a8e342c35c4f575c44179ba1a904c6 100644
--- a/src/AccCTR/src/factory_impl.cpp
+++ b/src/AccCTR/src/factory_impl.cpp
@@ -54,6 +54,17 @@ int FactoryImpl::CreateUnique(std::shared_ptr<Unique> &out)
     return H_OK;
 }
 
+int FactoryImpl::CreateEmbCacheManager(std::shared_ptr<EmbCache::EmbCacheManager> &out)
+{
+    auto tmp = new (std::nothrow) EmbCache::EmbCacheManagerImpl();
+    if (tmp == nullptr) {
+        return H_NEW_OBJECT_FAILED;
+    }
+
+    out.reset(dynamic_cast<EmbCache::EmbCacheManager *>(tmp));
+    return H_OK;
+}
+
 int FactoryImpl::SetExternalLogFuncInner(ExternalLog logFunc)
 {
     auto logger = ExternalLogger::Instance();
diff --git a/src/AccCTR/src/factory_impl.h b/src/AccCTR/src/factory_impl.h
index cc1c025a99b6dd629cbcdcfffde1157e38b68e06..aa5cd211c5e0f34a55c855f9931fb4516ecb42a6 100644
--- a/src/AccCTR/src/factory_impl.h
+++ b/src/AccCTR/src/factory_impl.h
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "include/factory.h"
 #include "unique/unique_impl.h"
+#include "embedding_cache/cache_manager/cache_manager.h"
 
 namespace ock {
 namespace ctr {
@@ -27,6 +28,7 @@ public:
 
 public:
     int CreateUnique(std::shared_ptr<Unique> &out) override;
+    int CreateEmbCacheManager(std::shared_ptr<EmbCache::EmbCacheManager> &out) override;
     int SetExternalLogFuncInner(ExternalLog logFunc) override;
 
 public:
diff --git a/src/AccCTR/src/include/CMakeLists.txt b/src/AccCTR/src/include/CMakeLists.txt
index c9d2b21563a1605e91de0d6372e5449474d5ac69..7f8b2b6d23434455ead4844b979094b8bafb64c8 100644
--- a/src/AccCTR/src/include/CMakeLists.txt
+++ b/src/AccCTR/src/include/CMakeLists.txt
@@ -12,7 +12,7 @@
 # limitations under the License.
 # ==============================================================================
 
-set(INCLUDE_HEADERS factory.h ock_ctr_common_def.h unique.h)
+set(INCLUDE_HEADERS factory.h ock_ctr_common_def.h unique.h embedding_cache.h)
 
 set(TARGET_INSTALL_INCLUDE ${OUTPUT}/ock_ctr_common/include)
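The factory changes above plug the new cache manager into the existing creation path. The fragment below sketches how a caller would obtain an EmbCacheManager; the name of the static creator (Create) is an assumption for illustration, since only its body appears in this diff.

```cpp
#include "factory.h"

// Hypothetical helper; ock::ctr::Factory::Create stands in for the static creator
// whose body is patched above (it wraps OckCtrCommonDef::CreateFactory).
int BuildCacheManager(ock::ctr::EmbCacheManagerPtr &manager)
{
    ock::ctr::FactoryPtr factory;
    int rc = ock::ctr::Factory::Create(factory);
    if (rc != 0) {
        return rc;
    }
    // CreateEmbCacheManager is the virtual added by this change.
    return factory->CreateEmbCacheManager(manager);
}
```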
diff --git a/src/AccCTR/src/include/embedding_cache.h b/src/AccCTR/src/include/embedding_cache.h
new file mode 100644
index 0000000000000000000000000000000000000000..c0468549312a8b0a4834b0774eaa9e10eaadcf3b
--- /dev/null
+++ b/src/AccCTR/src/include/embedding_cache.h
@@ -0,0 +1,341 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+ ==============================================================================*/
+
+#ifndef EMBEDDING_CACHE_H
+#define EMBEDDING_CACHE_H
+
+#include <cstdint>
+#include <memory>
+#include <random>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace EmbCache {
+using KeyOffsetPair = std::pair<std::vector<uint64_t>, std::vector<uint64_t>>;
+
+class Initializer {
+public:
+    Initializer() = default;
+    virtual ~Initializer() = default;
+
+    /* *
+     * Generate initial values
+     * @Param emb: start address of the embedding
+     * @Param embSize: number of values to generate
+     */
+    virtual void GenerateData(float* emb, int embSize) = 0;
+    uint32_t start{};      // start position
+    uint32_t len{};        // length to initialize
+    float initParam = 1.0; // every generated value is multiplied by initParam
+};
+
+enum class InitializerType {
+    INVALID,
+    CONSTANT,
+    TRUNCATED_NORMAL,
+    RANDOM_NORMAL
+};
+
+struct ConstantInitializerInfo {
+    ConstantInitializerInfo() = default;
+
+    ConstantInitializerInfo(float constantValue, float initK);
+
+    float constantValue = 0; // constant value
+    float initK = 1.0;       // generated values are multiplied by initK
+};
+
+struct NormalInitializerInfo {
+    NormalInitializerInfo() = default;
+
+    NormalInitializerInfo(float mean, float stddev, uint32_t seed, float initK);
+
+    float mean = 0;    // mean
+    float stddev = 0;  // standard deviation
+    uint32_t seed = 0; // random seed
+    float initK = 1.0; // generated values are multiplied by initK
+};
+
+class ConstantInitializer : public Initializer {
+public:
+    ConstantInitializer() = default;
+
+    ConstantInitializer(uint32_t start, uint32_t len, float value, float initK);
+
+    ~ConstantInitializer() override = default;
+
+    void GenerateData(float* emb, int embSize) override;
+
+    uint32_t start = 0;      // start position
+    uint32_t len = 0;        // length to initialize
+    float constantValue = 0; // constant value
+};
+
+class RandomNormalInitializer : public Initializer {
+public:
+    RandomNormalInitializer() = default;
+    RandomNormalInitializer(uint32_t start, uint32_t len, NormalInitializerInfo& initInfo);
+
+    ~RandomNormalInitializer() override = default;
+
+    void GenerateData(float* emb, int embSize) override;
+
+    uint32_t start = 0; // start position
+    uint32_t len = 0;   // length to initialize
+    float mean = 0;     // mean
+    float stddev = 0;   // standard deviation
+    uint32_t seed = 0;  // random seed
+
+    std::default_random_engine generator;         // random number engine
+    std::normal_distribution<float> distribution; // normal distribution
+};
+
+class TruncatedNormalInitializer : public Initializer {
+public:
+    TruncatedNormalInitializer() = default;
+
+    TruncatedNormalInitializer(uint32_t start, uint32_t len, NormalInitializerInfo& initInfo);
+
+    ~TruncatedNormalInitializer() override = default;
+
+    void GenerateData(float* emb, int embSize) override;
+
+    int boundNum = 2;
+
+    uint32_t start = 0; // start position
+    uint32_t len = 0;   // length to initialize
+    float mean = 0;     // mean
+    float stddev = 0;   // standard deviation
+    uint32_t seed = 0;  // random seed
+
+    std::default_random_engine generator; // random number engine
+    std::normal_distribution<float> distribution;
+    float minBound = 0; // lower bound
+    float maxBound = 0; // upper bound
+};
+
+struct InitializerInfo {
+    InitializerInfo() = default;
+
+    InitializerInfo(std::string& name, uint32_t start, uint32_t len, ConstantInitializerInfo constantInitializerInfo);
+
+    InitializerInfo(std::string& name, uint32_t start, uint32_t len, NormalInitializerInfo normalInitializerInfo);
+
+    std::string name = ""; // initializer name
+    uint32_t start = 0;    // start position of initialization
+    uint32_t len = 0;      // length to initialize
+    InitializerType initializerType = InitializerType::INVALID;
+
+    ConstantInitializerInfo constantInitializerInfo;
+    NormalInitializerInfo normalInitializerInfo;
+
+    std::shared_ptr<Initializer> initializer;
+};
+
+struct EmbCacheInfo {
+    EmbCacheInfo(std::string tableName, uint32_t vocabSize, uint32_t embeddingSize, uint32_t extEmbeddingSize,
+                 uint32_t maxCacheSize)
+        : tableName(tableName),
+          vocabSize(vocabSize),
+          embeddingSize(embeddingSize),
+          extEmbeddingSize(extEmbeddingSize),
+          maxCacheSize(maxCacheSize)
+    {
+    }
+    std::string tableName = "";
+    uint32_t vocabSize = 0;        // host-side capacity (number of embeddings that can be stored)
+    uint32_t embeddingSize = 0;
+    uint32_t extEmbeddingSize = 0; // embedding length including the optimizer state
+    uint32_t maxCacheSize = 0;     // device-side capacity (number of embeddings that can be stored)
+};
+
+class EmbCacheManager {
+public:
+    virtual ~EmbCacheManager() = default;
+
+    /* *
+     * Initialize the table described by embCacheInfo in the cache manager
+     * @Param embCacheInfo: initialization info of the embedding cache
+     * @Param initializerInfos: initializer configurations
+     * @Param prefillBufferSize: number of embeddings kept constantly available in the memory pool
+     * @Param refillThreadNum: number of threads that automatically refill the memory pool
+     * @Return errorCode
+     */
+    virtual int CreateCacheForTable(const EmbCacheInfo& embCacheInfo,
+                                    const std::vector<InitializerInfo>& initializerInfos, int64_t invalidKey = -1,
+                                    uint64_t prefillBufferSize = 500000, uint32_t refillThreadNum = 1) = 0;
+
+    /* *
+     * Look up the offsets of the given keys; keys not yet in the offsetMapper are inserted
+     * and assigned offsets. When the offsetMapper runs out of space, keys that may be swapped
+     * out are released; the key/offset pairs that must be swapped in and out are returned
+     * @Param tableName: table name
+     * @Param keys: all unique keys of the current batch
+     * @Param swapInKoPair: output, key-offset pairs to swap in
+     * @Param swapOutKoPair: output, key-offset pairs to swap out
+     * @Return errorCode
+     */
+    virtual int GetSwapPairsAndKey2Offset(const std::string& tableName, std::vector<uint64_t>& keys,
+                                          KeyOffsetPair& swapInKoPair, KeyOffsetPair& swapOutKoPair) = 0;
+
+    /* *
+     * Look up embeddings
+     * @Param tableName: table name
+     * @Param keys: keys to look up
+     * @Param embAddr: start address of the buffer allocated for the embeddings
+     * @Param threadNum: number of threads
+     * @Return errorCode
+     */
+    virtual int EmbeddingLookup(const std::string& tableName, const std::vector<uint64_t>& keys, float* embAddr,
+                                uint32_t threadNum = 4) = 0;
+
+    /* *
+     * Look up the addresses of embeddings
+     * @Param tableName: table name
+     * @Param keys: keys to look up
+     * @Param addrs: per-key start addresses of the buffers holding the embeddings
+     * @Param threadNum: number of threads
+     * @Return errorCode
+     */
+    virtual int EmbeddingLookupAddrs(const std::string& tableName, const std::vector<uint64_t>& keys,
+                                     std::vector<float*>& addrs, uint32_t threadNum = 4) = 0;
+
+    /* *
+     * Look up embeddings and remove their keys once the lookup completes. With multiple
+     * threads, the caller must strictly guarantee that keys do not repeat across threads
+     * (unique keys); otherwise the result is undefined
+     * @Param tableName: table name
+     * @Param keys: keys to look up
+     * @Param embAddr: start address of the buffer allocated for the embeddings
+     * @Param threadNum: number of threads
+     * @Return errorCode
+     */
+    virtual int EmbeddingLookupAndRemove(const std::string& tableName, const std::vector<uint64_t>& keys,
+                                         float* embAddr, uint32_t threadNum = 4) = 0;
+
+    /* *
+     * Update embeddings
+     * @Param tableName: table name
+     * @Param keys: keys to update, used to look up where each key is stored in DDR
+     * @Param embAddr: start address of the embeddings to write back to DDR
+     * @Param threadNum: number of threads
+     * @Return errorCode
+     */
+    virtual int EmbeddingUpdate(const std::string& tableName, const std::vector<uint64_t>& keys, float* embAddr,
+                                uint32_t threadNum = 4) = 0;
+
+    /* *
+     * Remove keys from the EmbLocalTable and mark the memory holding their embeddings as reusable
+     * @Param tableName: table name
+     * @Param keys: keys to remove
+     * @Return errorCode
+     */
+    virtual int EmbeddingRemove(const std::string& tableName, const std::vector<uint64_t>& keys,
+                                uint32_t threadNum = 4) = 0;
+
+    /* *
+     * Remove evicted keys from the offsetMapper records and from the EmbLocalTable, and mark
+     * the memory holding their embeddings as reusable
+     * @Param tableName: table name
+     * @Param keys: keys to evict
+     * @Return errorCode
+     */
+    virtual int RemoveEmbsByKeys(const std::string& tableName, const std::vector<uint64_t>& keys) = 0;
+
+    /* *
+     * Get all table names
+     * @Param allTableNames: output, receives all table names
+     * @Return errorCode
+     */
+    virtual int GetEmbTableNames(std::vector<std::string>& allTableNames) = 0;
+
+    /* *
+     * Get all key/offset pairs currently recorded in the offsetMapper, sorted by offset
+     * in ascending order
+     * @Param tableName: table name
+     * @Param koVec: output parameter
+     * @Return errorCode
+     */
+    virtual int ExportDeviceKeyOffsetPairs(const std::string& tableName,
+                                           std::vector<std::pair<uint64_t, uint64_t>>& koVec) = 0;
+
+    /* *
+     * Get the serialized state of the given table
+     * @Param tableName: table to serialize
+     * @Param buffer: output, receives the serialized data
+     * @Return errorCode
+     */
+    virtual int Serialize(const std::string& tableName, std::vector<char>& buffer) = 0;
+
+    /* *
+     * Deserialize the given table from its serialized state
+     * @Param tableName: table to deserialize
+     * @Param buffer: input, content to deserialize
+     * @Return errorCode
+     */
+    virtual int Deserialize(const std::string& tableName, const std::vector<char>& buffer) = 0;
+
+    /* *
+     * Destroy all embCaches and release their memory
+     */
+    virtual void Destroy() = 0;
+
+    /* *
+     * Query the usage of a table
+     * @Param tableName: table to query
+     * @Return current usage of the table
+     */
+    virtual uint32_t GetUsage(const std::string& tableName) = 0;
+
+    /* *
+     * Get all keys stored on the host side together with their embeddings and optimizer
+     * parameters
+     * @Param tableName: name of the table to query
+     * @Param keys: pass an empty vector; all stored keys are assigned to it
+     * @Param embeddings: pass an empty vector; all stored embeddings are assigned to it
+     * @Param optimizerSlots: pass an empty vector; all stored optimizerSlots are assigned to it
+     * @Return errorCode
+     */
+    virtual int GetEmbTableInfos(std::string tableName, std::vector<uint64_t>& keys,
+                                 std::vector<std::vector<float>>& embeddings,
+                                 std::vector<std::vector<float>>& optimizerSlots) = 0;
+
+    /* *
+     * Load the LocalEmbeddingTable with the given keys and their embeddings and optimizer
+     * parameters
+     * @Param tableName: name of the table to load
+     * @Param keys: keys to load
+     * @Param embeddings: embeddings to load
+     * @Param optimizerSlots: optimizerSlots to load
+     * @Return errorCode
+     */
+    virtual int LoadEmbTableInfos(std::string tableName, const std::vector<uint64_t>& keys,
+                                  const std::vector<std::vector<float>>& embeddings,
+                                  const std::vector<std::vector<float>>& optimizerSlots) = 0;
+
+    /* *
+     * When switching the channel to eval, back up the current table's offsetMapper object
+     * @Param tableName: embedding table name
+     * @Return errorCode
+     */
+    virtual int BackUpTrainStatus(const std::string& tableName) = 0;
+
+    /* *
+     * When switching the eval channel back to train, restore the current table's offsetMapper
+     * object to the backed-up state
+     * @Param tableName: embedding table name
+     * @Return errorCode
+     */
+    virtual int RecoverTrainStatus(const std::string& tableName) = 0;
+
+    /* *
+     * Reset the offsetMapper objects to their freshly initialized state after loading
+     * @Return errorCode
+     */
+    virtual int ResetOffsetMappers() = 0;
+};
+} // namespace EmbCache
+
+#endif // EMBEDDING_CACHE_H
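Taken together, the interface above is used per table: create the cache, map a batch's keys to offsets, then read or write the host-side embeddings. A condensed sketch of that sequence follows; the table dimensions and the empty initializer list are placeholders, not recommended values.

```cpp
#include <cstdint>
#include <vector>
#include "embedding_cache.h"

int RunOneBatch(EmbCache::EmbCacheManager &mgr)
{
    // 1M embeddings on the host, 100k in the device cache, 8 floats per embedding
    // (16 including optimizer state).
    EmbCache::EmbCacheInfo info("user_emb", 1000000, 8, 16, 100000);
    std::vector<EmbCache::InitializerInfo> inits;  // placeholder: no explicit initializers
    int rc = mgr.CreateCacheForTable(info, inits);
    if (rc != 0) {
        return rc;
    }

    std::vector<uint64_t> keys = {1, 2, 3};  // unique keys of the current batch
    EmbCache::KeyOffsetPair swapIn;
    EmbCache::KeyOffsetPair swapOut;
    rc = mgr.GetSwapPairsAndKey2Offset("user_emb", keys, swapIn, swapOut);
    if (rc != 0) {
        return rc;
    }

    // Fetch the host-side embeddings that must be swapped in to the device.
    std::vector<float> buf(swapIn.first.size() * 8);
    return mgr.EmbeddingLookup("user_emb", swapIn.first, buf.data());
}
```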
diff --git a/src/AccCTR/src/include/factory.h b/src/AccCTR/src/include/factory.h
index 14732cf92dc397bf8ca769223445e0520a087dde..69e8217a792e7d3a07cede50ba33ad0c7ce192dc 100644
--- a/src/AccCTR/src/include/factory.h
+++ b/src/AccCTR/src/include/factory.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include <cstdint>
 #include <memory>
 #include "unique.h"
+#include "embedding_cache.h"
 
 
 #ifdef __cplusplus
@@ -39,11 +40,13 @@ class Factory;
 
 using FactoryPtr = std::shared_ptr<Factory>;
 using UniquePtr = std::shared_ptr<Unique>;
+using EmbCacheManagerPtr = std::shared_ptr<EmbCache::EmbCacheManager>;
 
 class Factory {
 public:
     virtual ~Factory() = default;
     virtual int CreateUnique(UniquePtr &out) = 0;
+    virtual int CreateEmbCacheManager(EmbCacheManagerPtr &out) = 0;
     virtual int SetExternalLogFuncInner(ExternalLog logFunc) = 0;
 
 public:
@@ -52,7 +55,7 @@ public:
         int result = 0;
         uintptr_t factory = 0;
         /* dynamic load function */
-        if ((result = OckCtrCommonDef::CreatFactory(&factory)) == 0) {
+        if ((result = OckCtrCommonDef::CreateFactory(&factory)) == 0) {
             out.reset(reinterpret_cast<Factory *>(factory));
         }
         return result;
diff --git a/src/AccCTR/src/include/ock_ctr_common_def.h b/src/AccCTR/src/include/ock_ctr_common_def.h
index ed9559961df7ecb9772d193552ce503b19c0fc7f..75e7e9cb4e36b7b8a4aceb25be045ce9146c6a39 100644
--- a/src/AccCTR/src/include/ock_ctr_common_def.h
+++ b/src/AccCTR/src/include/ock_ctr_common_def.h
@@ -25,7 +25,7 @@ namespace ock {
 namespace ctr {
 class OckCtrCommonDef {
 public:
-    static int CreatFactory(uintptr_t *factory)
+    static int CreateFactory(uintptr_t *factory)
     {
         static void *handle = nullptr;
         static std::mutex m;
diff --git a/src/AccCTR/src/include/unique.h b/src/AccCTR/src/include/unique.h
index 3154a784fb1ad493a310a5eada2dcfa07231c925..1f58f8a44d5dae05dd3888d2a0f027bbe317aa73 100644
--- a/src/AccCTR/src/include/unique.h
+++ b/src/AccCTR/src/include/unique.h
@@ -58,6 +58,7 @@ using UniqueConf = struct UniqueConfCTR {
     uint32_t maxThreadNum = 8; // maximum number of worker threads
     int64_t maxIdVal = 0;      // maximum id value
     bool trace = false;        // enable performance tracing; requires an external log sink
+    bool performance = false;  // enable the enhanced interface; shardingNum must then be a power of two (modulo bucketing is the default)
 } __attribute__((packed));
 
 using UniqueIn = struct UniqueInCTR {
diff --git a/src/AccCTR/src/unique/unique_func.cpp b/src/AccCTR/src/unique/unique_func.cpp
index 64ad6d52dd6b67f6a4b9b85a556d9ce006476d7b..45ac768af2cf51879fdefc330fb6cb31710c54ad 100644
--- a/src/AccCTR/src/unique/unique_func.cpp
+++ b/src/AccCTR/src/unique/unique_func.cpp
@@ -27,7 +27,6 @@ void Dedup::Insert(uint64_t val)
 
     for (int8_t i = 0; i < count; ++i) {
         if (bucket->data[totalCount] == val) {
-            TryIncreaseIdCount(bucket->idCount[totalCount]);
             // found one
             return;
         }
@@ -38,7 +37,6 @@ void Dedup::Insert(uint64_t val)
     std::lock_guard lg(bucket->lock);
     for (int8_t j = totalCount; j < bucket->count; ++j) {
         if (bucket->data[totalCount] == val) {
-            TryIncreaseIdCount(bucket->idCount[totalCount]);
             // found one
             return;
         }
@@ -47,7 +45,6 @@ void Dedup::Insert(uint64_t val)
         if (totalCount < n) {
             bucket->data[totalCount] = val;
             bucket->count++;
-            TryIncreaseIdCount(bucket->idCount[totalCount]);
             return;
         }
     }
 
     InsertOverflow(val);
 }
 
-inline void Dedup::TryIncreaseIdCount(std::atomic<int32_t> &val)
-{
-    if (idCountEnable_) {
-        val++;
-    }
-}
-
 int32_t Dedup::GetReplaceOffsetUnsafe(uint64_t val)
 {
     auto h = static_cast<uint32_t>(Hash(val) & bucketCountMask_);
@@ -108,7 +98,6 @@ void Dedup::Clear(uint64_t newBucketCountPowerOf2)
     }
     bzero(table_, sizeof(Meta) * bucketCount_);
     overflow_.clear();
-    idCountOverflow_.clear();
 }
 
 void Dedup::NewParameter()
@@ -119,10 +108,12 @@ void Dedup::NewParameter()
 
     // Time to check the proper size of sharded tables for performance
     // sake.
uint64_t shardedTableSize = 0; - if (std::numeric_limits::max() / n / groupCount_ < newBucketCountPowerOf2) { - shardedTableSize = std::numeric_limits::max(); + if (std::numeric_limits::max() / static_cast(n) / + static_cast(groupCount_) < + newBucketCountPowerOf2) { + shardedTableSize = static_cast(std::numeric_limits::max()); } else { - shardedTableSize = newBucketCountPowerOf2 * n * groupCount_; + shardedTableSize = newBucketCountPowerOf2 * n * static_cast(groupCount_); } int largeCount = 0; @@ -166,6 +157,58 @@ int32_t ShardedDedup::GetFillOffset(const std::vector &totalUniqueSize, } } +void ShardedDedup::GetIndexAndStart(const int32_t *uniqueSizeInBucket, bool usePadding, int shardingNumber, int &start, + int &index) +{ + if (shardingNumber > 0) { + index += uniqueSizeInBucket[shardingNumber - 1]; + } + + if (usePadding) { + start = shardingNumber * conf.paddingSize; + } else { + start = index; + } +} + +int ShardedDedup::PrintMemCpyLog(int rc, const uint32_t dstSize, const std::string &logMsg) +{ + if (rc != 0) { + std::stringstream ssm; + ssm << "[" << logMsg << "] memcpy_s failed... dstSize: " << dstSize; + ExternalLogger::PrintLog(LogLevel::ERROR, ssm.str()); + return H_COPY_ERROR; + } else { + return H_OK; + } +} + +int ShardedDedup::HandleIdCountFill(std::vector> &idCount, UniqueOutSelf &uniqueOut) +{ + if (conf.usePadding) { + uint32_t memSize = idCount.size() * sizeof(int32_t); + auto rc = memcpy_s(uniqueOut.idCntFill, memSize, (int32_t *)(idCount.data()), memSize); + if (rc != 0) { + return rc; + } + int ret = PrintMemCpyLog(rc, memSize, "[TileAndFill/idCntFill]"); + if (ret != 0) { + return ret; + } + } else { + uint32_t memSize = idCount.size() * sizeof(int32_t); + auto rc = memcpy_s(uniqueOut.idCnt, memSize, (int32_t *)(idCount.data()), memSize); + if (rc != 0) { + return rc; + } + + int ret = PrintMemCpyLog(rc, memSize, "[TileAndFill/idCnt]"); + if (ret != 0) { + return ret; + } + } + return H_OK; +} size_t ShardedDedup::CalThreadNum() const { diff --git a/src/AccCTR/src/unique/unique_func.h b/src/AccCTR/src/unique/unique_func.h index 39e5a6b3290b3d70c3c83346cbd17d6f19855eb4..0222e4eb5a7efe5e504a16498585f7c52e26b511 100644 --- a/src/AccCTR/src/unique/unique_func.h +++ b/src/AccCTR/src/unique/unique_func.h @@ -30,6 +30,7 @@ limitations under the License. #include #include #include +#include #include "securec.h" #include "common_includes.h" @@ -37,6 +38,14 @@ limitations under the License. 
namespace ock { namespace ctr { +#ifndef LIKELY +#define LIKELY(x) __builtin_expect(!!(x), 1) +#endif + +#ifndef UNLIKELY +#define UNLIKELY(x) __builtin_expect(!!(x), 0) +#endif + using UniqueOutSelf = struct UniqueSelf { void *uniqueId = nullptr; // 去重分桶填充之后最终的的ids(需要用户申请)必选 uint32_t *index = nullptr; // 去重后id的索引位置(需要用户申请)必选 @@ -47,7 +56,7 @@ using UniqueOutSelf = struct UniqueSelf { int uniqueIdCnt = 0; // 每个桶去重后的id个数(需要用户申请) }; -constexpr int UNIQUE_MAX_BUCKET_WIDTH = 5; +constexpr int UNIQUE_MAX_BUCKET_WIDTH = 6; template struct Map {}; template <> struct Map { @@ -111,7 +120,7 @@ class Dedup { static constexpr uint32_t K_MINIMAL_WORKLOAD_PER_WORKER = 1 << 12; static constexpr size_t K_ALIGNMENT = 64; static const int kDefaultBucketCount = 1 << 24; - static const int8_t n = 4; + static const int8_t n = UNIQUE_MAX_BUCKET_WIDTH; template struct Meta { static_assert(M <= UNIQUE_MAX_BUCKET_WIDTH, "should be no larger than max bucket width"); @@ -119,7 +128,6 @@ class Dedup { volatile int8_t count {}; uint32_t replaceBase {}; volatile uint64_t data[M] {}; - std::atomic idCount[M] {}; } __attribute__((__aligned__(64))); struct Statistics { @@ -152,11 +160,10 @@ public: void Insert(uint64_t val); int32_t GetReplaceOffsetUnsafe(uint64_t val); void InitTable(); - void TryIncreaseIdCount(std::atomic &val); void Clear(uint64_t newBucketCountPowerOf2); void NewParameter(); - template uint32_t UniqueRaw(void *output, uint32_t priorTotal, int32_t *idCount) + template uint32_t UniqueRaw(void *output, uint32_t priorTotal) { uint32_t total = priorTotal; uint32_t replaceOffset = priorTotal; @@ -168,28 +175,24 @@ public: } bucket->replaceBase = replaceOffset; for (int j = 0; j < bucket->count; ++j) { - if (idCountEnable_) { - idCount[total] = bucket->idCount[j]; - } - out[total++] = bucket->data[j]; + out[total] = static_cast(bucket->data[j]); + ++total; } replaceOffset += bucket->count; } auto it = overflow_.begin(); int32_t totalOverflow = 0; while (it != overflow_.end()) { - if (idCountEnable_) { - idCount[total] = idCountOverflow_[it->first]; - } - out[total++] = it->first; + out[total] = it->first; it->second = replaceOffset++; + ++total; ++it; ++totalOverflow; } // set total overflow count stats_.totalUniques = static_cast(total - priorTotal); - stats_.totalOverflowUniques = totalOverflow; + stats_.totalOverflowUniques = static_cast(totalOverflow); return total - priorTotal; } @@ -200,14 +203,13 @@ private: int largeCount_ { 0 }; Meta *table_ {}; std::unordered_map overflow_; - std::unordered_map idCountOverflow_; SpinLockG overflowMutex_; Statistics stats_; bool idCountEnable_ { false }; static inline uint64_t Hash(uint64_t val) { - return val ^ (val >> HASH_L_L) ^ (val >> HASH_L_L) ^ (val >> HASH_H); + return val ^ (val >> HASH_L_L) ^ (val >> HASH_L) ^ (val >> HASH_H); } void InsertOverflow(uint64_t val) @@ -217,10 +219,6 @@ private: if (it == overflow_.end()) { overflow_[val] = 0; } - - if (idCountEnable_) { - idCountOverflow_[val]++; - } } int32_t GetReplaceOffsetFromOverflowUnsafe(uint64_t val) @@ -234,6 +232,7 @@ class ShardedDedup { static constexpr uint32_t K_MINIMAL_WORKLOAD_PER_WORKER = 1 << 13; static constexpr int K_DEFAULT_DUPLICATE_RATIO = 4; static constexpr int K_BUCKET_WIDTH = 4; + static constexpr int CLEAR_WAIT_TIME = 10; public: using DedupT = Dedup; @@ -244,42 +243,42 @@ public: { const int numOfGroupsInShard = groupMethod_.GroupCount(); uint32_t totalSize = conf.desiredSize + (conf.desiredSize >> 1); - while (bucketCountPower2_ * K_BUCKET_WIDTH * numOfGroupsInShard * 
estimatedDuplicateRatio < totalSize) { + while (static_cast(bucketCountPower2_ * K_BUCKET_WIDTH * numOfGroupsInShard * + estimatedDuplicateRatio) < totalSize) { bucketCountPower2_ <<= 1; } idCountEnable_ = (conf.outputType == OutputType::ENHANCED) && conf.useIdCount; for (int32_t i = 0; i < numOfGroupsInShard; ++i) { auto obj = new DedupT(bucketCountPower2_, numOfGroupsInShard, idCountEnable_); - if (obj == nullptr) { - ExternalLogger::PrintLog(LogLevel::ERROR, "creat object error"); - throw NullptrError(); - } dedupShards_.emplace_back(obj); } } ~ShardedDedup() = default; - void StartNewRound() + int StartNewRound() { for (auto &s : dedupShards_) { s->NewParameter(); } + clearFinish_ = true; + return 0; } public: template int Compute(UniqueIn &uniqueIn, UniqueOutSelf &uniqueOut) { - try { - if (!firstEnterFlag_) { - StartNewRound(); - } - } catch (AllocError &) { - ExternalLogger::PrintLog(LogLevel::ERROR, "memory alloc error"); - return H_MEMORY_ALLOC_ERROR; + if (firstEnter_) { + pool_.SetNumThreads(1); + firstEnter_ = false; } - firstEnterFlag_ = false; + + while (!clearFinish_) { + usleep(CLEAR_WAIT_TIME); + } + + clearFinish_ = false; size_t threadNum = CalThreadNum(); partSize = (uniqueIn.inputIdCnt + threadNum - 1) / threadNum; @@ -302,23 +301,29 @@ public: if (conf.outputType == OutputType::ENHANCED) { int totalNumber = 0; for (int i = 0; i < conf.shardingNum; i++) { - totalUniqueSize[i] = totalNumber; + totalUniqueSize[i] = static_cast(totalNumber); if (conf.useSharding) { totalNumber += uniqueOut.uniqueIdCntInBucket[i]; } } } - ret = CalUniqueOut(uniqueIn, uniqueOut, totalUniqueSize); + int size = 1; + if (conf.useIdCount) { + size = conf.usePadding ? conf.paddingSize * conf.shardingNum : uniqueOut.uniqueIdCnt; + } + std::vector> idCount(size); + ret = CalUniqueOut(uniqueIn, uniqueOut, totalUniqueSize, idCount); if (ret != H_OK) { ExternalLogger::PrintLog(LogLevel::ERROR, "CalUniqueOut ERROR"); return ret; } if (conf.outputType == OutputType::ENHANCED) { - HandleTileAndFill(uniqueIn, uniqueOut); + HandleTileAndFill(uniqueOut, idCount); } + pool_.AddTask([this]() { return StartNewRound(); }); return H_OK; } @@ -334,17 +339,22 @@ private: int32_t GetFillOffset(const std::vector &totalUniqueSize, int64_t val, int32_t group); - template int HandleTileAndFill(UniqueIn &uniqueIn, UniqueOutSelf &uniqueOut) + void GetIndexAndStart(const int32_t *uniqueSizeInBucket, bool usePadding, int shardingNumber, int &start, + int &index); + + int PrintMemCpyLog(int rc, const uint32_t dstSize, const std::string &logMsg); + + int HandleIdCountFill(std::vector> &idCount, UniqueOutSelf &uniqueOut); + + template int HandleTileAndFill(UniqueOutSelf &uniqueOut, std::vector> &idCount) { int ret = H_OK; if (conf.useSharding) { // 使能shard - ret = TileAndFill(uniqueOut.uniqueIdInBucket, uniqueOut.uniqueIdCntInBucket, uniqueOut.uniqueId, - uniqueOut.idCnt, uniqueOut.idCntFill); + ret = TileAndFill(uniqueOut, uniqueOut.uniqueIdCntInBucket, idCount); } else if (!conf.useSharding && conf.useIdCount) { // 不使能shard和使能特征计数 std::vector count; count.emplace_back(uniqueOut.uniqueIdCnt); // 记录去重后id个数 - ret = TileAndFill(uniqueOut.uniqueId, count.data(), uniqueOut.uniqueId, uniqueOut.idCnt, - uniqueOut.idCntFill); + ret = TileAndFill(uniqueOut, count.data(), idCount); } if (ret != H_OK) { @@ -363,37 +373,37 @@ private: uint64_t inGroupTotal; if (conf.outputType == OutputType::ENHANCED) { if (conf.useSharding && conf.useIdCount) { - inGroupTotal = dedupShards_[j]->UniqueRaw(uniqueOut.uniqueIdInBucket, total, - 
uniqueOut.idCnt); // 特征计数使能和shard同时使能 - uniqueOut.uniqueIdCntInBucket[j] = inGroupTotal; + inGroupTotal = + dedupShards_[j]->UniqueRaw(uniqueOut.uniqueIdInBucket, total); // 特征计数使能和shard同时使能 + uniqueOut.uniqueIdCntInBucket[j] = static_cast(inGroupTotal); } else if (!conf.useSharding && conf.useIdCount) { - inGroupTotal = dedupShards_[j]->UniqueRaw(uniqueOut.uniqueId, total, - uniqueOut.idCnt); // 特征计数使能和shard不使能 + inGroupTotal = + dedupShards_[j]->UniqueRaw(uniqueOut.uniqueId, total); // 特征计数使能和shard不使能 } else if (conf.useSharding && !conf.useIdCount) { - inGroupTotal = dedupShards_[j]->UniqueRaw(uniqueOut.uniqueIdInBucket, total, - nullptr); // 特征计数使能和shard不使能 - uniqueOut.uniqueIdCntInBucket[j] = inGroupTotal; + inGroupTotal = + dedupShards_[j]->UniqueRaw(uniqueOut.uniqueIdInBucket, total); // 特征计数使能和shard不使能 + uniqueOut.uniqueIdCntInBucket[j] = static_cast(inGroupTotal); } else { - inGroupTotal = dedupShards_[j]->UniqueRaw(uniqueOut.uniqueId, total, - nullptr); // 特征计数不使能和shard不使能,跟普通unique对等 + inGroupTotal = dedupShards_[j]->UniqueRaw(uniqueOut.uniqueId, + total); // 特征计数不使能和shard不使能,跟普通unique对等 } } else { - inGroupTotal = dedupShards_[j]->UniqueRaw(uniqueOut.uniqueId, total, nullptr); + inGroupTotal = dedupShards_[j]->UniqueRaw(uniqueOut.uniqueId, total); } - total += inGroupTotal; + total += static_cast(inGroupTotal); } uniqueOut.uniqueIdCnt = total; } template - int TileAndFill(void *uniqueIdInBucket, const int32_t *uniqueSizeInBucket, void *uniqueIds, const int32_t *idCnt, - int32_t *idCntFill) + int TileAndFill(UniqueOutSelf &uniqueOut, const int32_t *uniqueSizeInBucket, + std::vector> &idCount) { int start = 0; int index = 0; - auto uIdInBucket = TypeTrans(uniqueIdInBucket); - auto uIds = TypeTrans(uniqueIds); + auto uIdInBucket = TypeTrans(conf.useSharding ? uniqueOut.uniqueIdInBucket : uniqueOut.uniqueId); + auto uIds = TypeTrans(uniqueOut.uniqueId); for (int i = 0; i < conf.shardingNum; i++) { GetIndexAndStart(uniqueSizeInBucket, conf.usePadding, i, start, index); @@ -417,35 +427,31 @@ private: if (conf.useIdCount && conf.usePadding) { memSize = uniqueSizeInBucket[i] * sizeof(int32_t); - rc = memcpy_s(idCntFill + start, memSize, idCnt + index, memSize); - ret = PrintMemCpyLog(rc, memSize, "[TileAndFill/idCntFill]"); + rc = memcpy_s(uniqueOut.idCnt + index, memSize, (int32_t *)(idCount.data()) + start, + memSize); // 填充idCount + ret = PrintMemCpyLog(rc, memSize, "[TileAndFill/idCnt]"); + } + + if (ret != 0) { + return ret; } + } + + if (conf.useIdCount) { + int ret = HandleIdCountFill(idCount, uniqueOut); if (ret != 0) { return ret; } } if (conf.usePadding) { - HandleFill(uIds, uniqueSizeInBucket, idCntFill); + HandleFill(uIds, uniqueSizeInBucket); } return H_OK; } - int PrintMemCpyLog(int rc, const uint32_t dstSize, const std::string &logMsg) - { - if (rc != 0) { - std::stringstream ssm; - ssm << "[" << logMsg << "] memcpy_s failed... 
dstSize: " << dstSize; - ExternalLogger::PrintLog(LogLevel::ERROR, ssm.str()); - return H_COPY_ERROR; - } else { - return H_OK; - } - } - - template - void HandleFill(typename Map::type *uIds, const int32_t *uniqueSizeInBucket, int32_t *idCntFill) + template void HandleFill(typename Map::type *uIds, const int32_t *uniqueSizeInBucket) { int start = 0; int index = 0; @@ -457,26 +463,6 @@ private: for (int j = 0; j < fillLen; j++) { uIds[start + uniqueSizeInBucket[i] + j] = conf.paddingVal; // padding填充 } - - if (idCntFill != nullptr) { - for (int y = 0; y < fillLen; y++) { - idCntFill[start + uniqueSizeInBucket[i] + y] = 0; // 特征计数填充 - } - } - } - } - - void GetIndexAndStart(const int32_t *uniqueSizeInBucket, bool usePadding, int shardingNumber, int &start, - int &index) - { - if (shardingNumber > 0) { - index += uniqueSizeInBucket[shardingNumber - 1]; - } - - if (usePadding) { - start = shardingNumber * conf.paddingSize; - } else { - start = index; } } @@ -491,13 +477,18 @@ private: tasks.push_back([this, val, start, end, &ret]() { for (uint64_t j = start; j < end; ++j) { auto value = val[j]; - if (value > conf.maxIdVal) { + if (UNLIKELY(value > conf.maxIdVal)) { ExternalLogger::PrintLog(LogLevel::ERROR, "id val is larger than maxIdVal"); ret = H_ID_LARGE; break; } - auto group = groupMethod_.GroupId(value); - dedupShards_[group]->Insert(value); + + if (conf.performance) { + dedupShards_[value & (conf.shardingNum - 1)]->Insert(value); + } else { + auto group = groupMethod_.GroupId(value); + dedupShards_[group]->Insert(value); + } } }); } @@ -518,7 +509,8 @@ private: } template - int CalUniqueOut(UniqueIn &uniqueIn, UniqueOutSelf &uniqueOut, std::vector &totalUniqueSize) + int CalUniqueOut(UniqueIn &uniqueIn, UniqueOutSelf &uniqueOut, std::vector &totalUniqueSize, + std::vector> &idCount) { uint32_t *beginPtr = uniqueOut.index; uint32_t *finishPtr = beginPtr + uniqueIn.inputIdCnt; @@ -531,18 +523,32 @@ private: if (partEndPtr > finishPtr) { partEndPtr = finishPtr; } - if (partBeginPtr < partEndPtr) { - // Due to cacheline alignment computation, the actual number of - // threads created here may not match threadNum exactly but - // should be +/-1 off. - tasks.push_back([this, val, beginPtr, partBeginPtr, partEndPtr, totalUniqueSize]() { - for (uint32_t *ptr = partBeginPtr; ptr < partEndPtr; ++ptr) { + + if (partBeginPtr >= partEndPtr) { + partBeginPtr = partEndPtr; + partEndPtr += partSize; + continue; + } + + // Due to cacheline alignment computation, the actual number of + // threads created here may not match threadNum exactly but + // should be +/-1 off. 
+            // Due to cacheline alignment computation, the actual number of
+            // threads created here may not match threadNum exactly but
+            // should be +/-1 off.
+            tasks.push_back([this, val, beginPtr, partBeginPtr, partEndPtr, totalUniqueSize, &idCount]() {
+                for (uint32_t *ptr = partBeginPtr; ptr < partEndPtr; ++ptr) {
+                    int32_t fillOffset;
+                    if (conf.performance) {
+                        fillOffset = GetFillOffset(totalUniqueSize, val[ptr - beginPtr],
+                                                   val[ptr - beginPtr] & (conf.shardingNum - 1));
+                    } else {
                         auto group = groupMethod_.GroupId(val[ptr - beginPtr]);
-                        int32_t fillOffset = GetFillOffset(totalUniqueSize, val[ptr - beginPtr], group);
-                        *ptr = fillOffset;
+                        fillOffset = GetFillOffset(totalUniqueSize, val[ptr - beginPtr], group);
                     }
-                });
-            }
+                    *ptr = fillOffset;
+                    if (LIKELY(conf.useIdCount)) {
+                        idCount[fillOffset]++;
+                    }
+                }
+            });
             partBeginPtr = partEndPtr;
             partEndPtr += partSize;
         }
@@ -567,8 +573,10 @@ private:
     UniqueConf conf;
     std::vector> dedupShards_ {};
     uint32_t partSize;
-    bool firstEnterFlag_ = false;
+    bool clearFinish_ = true;
     bool idCountEnable_ { false };
+    ThreadPoolAsync pool_;
+    bool firstEnter_ = true;
 };
 }
 }
diff --git a/src/AccCTR/src/unique/unique_impl.cpp b/src/AccCTR/src/unique/unique_impl.cpp
index 77113214234b42f84caef53c63115327988f011f..800f21de5c05c9f085cdd670edc23d4a964b0091 100644
--- a/src/AccCTR/src/unique/unique_impl.cpp
+++ b/src/AccCTR/src/unique/unique_impl.cpp
@@ -228,6 +228,14 @@ int UniqueImpl::CheckEnhancedUniqueConf(const UniqueConf &conf)
         if (CheckInputZero(conf.shardingNum, "shardingNum")) {
             return H_NUM_SMALL;
         }
+        if (conf.performance) {
+            bool isExponentOfTwo =
+                (conf.shardingNum > 0) && ((conf.shardingNum & (conf.shardingNum - 1)) == 0); // check that it is a power of two
+            if (!isExponentOfTwo) {
+                ExternalLogger::PrintLog(LogLevel::ERROR, "if performance is true, shardingNum must be 2^N");
+                return H_ERROR;
+            }
+        }
     }
 
     return H_OK;
diff --git a/src/AccCTR/src/unique/unique_impl.h b/src/AccCTR/src/unique/unique_impl.h
index f4c45fded514b4601832fdf2bd0b3483fae97c5b..e37a58dbf2d595754b8f90af3b1b182968baacb2 100644
--- a/src/AccCTR/src/unique/unique_impl.h
+++ b/src/AccCTR/src/unique/unique_impl.h
@@ -43,7 +43,7 @@ private:
 
 private:
     ShardedDedup *unique = nullptr;
-    UniqueConf uniqueConf {};
+    UniqueConf uniqueConf{};
 };
 }
 }
diff --git a/src/AccCTR/tests/tools/create_fake_id.py b/src/AccCTR/tests/tools/create_fake_id.py
index fc0f1f8ef53ccfc9800cbfdf0c082291b31ac7a4..aa42f0714d3e1b800d28583628ce6f1ff43c638e 100644
--- a/src/AccCTR/tests/tools/create_fake_id.py
+++ b/src/AccCTR/tests/tools/create_fake_id.py
@@ -68,12 +68,6 @@ def write_data(file_name, x, y, dup):
 
 
 def main():
-    # 3M ids, 20% dedup rate
-    # 6x + y = 300
-    # x + y = 60
-    # x = 48, y = 12
-    write_data('data20.txt', 48*10000, 12*10000, 6)
-
     # 3M ids, 30% dedup rate
     # 6x + y = 300
     # x + y = 90
diff --git a/src/AccCTR/tests/ut/conf/toolchain.cmake b/src/AccCTR/tests/ut/conf/toolchain.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..bd6617e4b7c9112ab810aa6cba64bd05e2c3da62
--- /dev/null
+++ b/src/AccCTR/tests/ut/conf/toolchain.cmake
@@ -0,0 +1,24 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Add compile options
+option(USE32BIT "Use 32-Bit" OFF)
+if(USE32BIT)
+    add_compile_options(-m32)
+    add_link_options(-m32)
+endif()
+
+add_compile_options(-Wall)
+set(CMAKE_C_STANDARD 11)
+set(CMAKE_CXX_STANDARD 11)
\ No newline at end of file
diff --git a/src/AccCTR/tests/ut/src/CMakeLists.txt b/src/AccCTR/tests/ut/src/CMakeLists.txt
index a4c631e885b9f810d7053e231f26c81b87d5dfd3..3da582446b667f4268486869b3865e22748f7d11 100644
--- a/src/AccCTR/tests/ut/src/CMakeLists.txt
+++ b/src/AccCTR/tests/ut/src/CMakeLists.txt
@@ -19,6 +19,11 @@ set(OCK_CTR_UTIL_INSTALL_DIR ${PROJECT_SOURCE_DIR}/install)
 set(OCK_CTR_SRC_DIR ${PROJECT_SOURCE_DIR}/src)
 message("src" ${OCK_CTR_SRC_DIR})
 
+# Include the shared toolchain cmake for all components
+include("${CMAKE_CURRENT_SOURCE_DIR}/../conf/toolchain.cmake")
+set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../src)
+set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../)
+
 file(GLOB_RECURSE TEST_UNIQUE_FILES *.cpp *.h)
 add_executable(test_unique_files ${TEST_UNIQUE_FILES})
 include_directories(${OCK_CTR_UTIL_INSTALL_DIR}/googletest-release-1.8.1/include)
@@ -29,17 +34,36 @@ SET(LIB_3RD_GTEST ${OCK_CTR_UTIL_INSTALL_DIR}/googletest-release-1.8.1/lib64/lib
 
 message(${OCK_CTR_SRC_DIR}/include)
 
+# Add library search paths
+target_link_directories(test_unique_files
+        PUBLIC
+        ${PROJECT_SOURCE_DIR}/output/ock_ctr_common/lib
+        )
+# Add header search paths
 target_include_directories(test_unique_files
         PUBLIC
-        ${OCK_CTR_SRC_DIR}/include)
+        ${OCK_CTR_SRC_DIR}/include
+        ${PROJECT_SOURCE_DIR}
+        ${OCK_CTR_SRC_DIR}/common/util
+        )
 
+# Libraries to link against
 target_link_libraries(test_unique_files
         PUBLIC
         -Wl,--start-group
+        _ock_ctr_common
         pthread
         dl
         ${LIB_3RD_GTEST}
         ${LIB_3RD_GMOCK}
         -Wl,--end-group)
 
+# Print build options
+get_target_property(COMPILE_FLAGS test_unique_files COMPILE_OPTIONS)
+get_target_property(LINK_FLAGS test_unique_files LINK_OPTIONS)
+message(STATUS "Compiler id: ${CMAKE_CXX_COMPILER_ID}")
+message(STATUS "Compile flags: ${COMPILE_FLAGS}")
+message(STATUS "Link flags: ${LINK_FLAGS}")
+message(STATUS "Build Type: ${CMAKE_BUILD_TYPE}")
+
diff --git a/src/AccCTR/tests/ut/src/common.h b/src/AccCTR/tests/ut/src/common.h
new file mode 100644
index 0000000000000000000000000000000000000000..7302d10cafc96827d01b9a609d456a970105a9d4
--- /dev/null
+++ b/src/AccCTR/tests/ut/src/common.h
@@ -0,0 +1,64 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+ ==============================================================================*/
+
+#ifndef CTR_COMMON_H
+#define CTR_COMMON_H
+#include <functional>
+#include <future>
+#include <iostream>
+#include <vector>
+
+#include "factory.h"
+
+extern ock::ctr::FactoryPtr factory;
+
+enum CTRLogLevel {
+    DEBUG = 0,
+    INFO,
+    WARN,
+    ERROR,
+};
+
+class SimpleThreadPool {
+public:
+    static void SyncRun(const std::vector<std::function<void()>> &tasks)
+    {
+        std::vector<std::future<void>> futs;
+        for (auto &task : tasks) {
+            futs.push_back(std::async(task));
+        }
+        for (auto &fut : futs) {
+            fut.wait();
+        }
+    }
+};
+
+static void CTRLog(int level, const char *msg)
+{
+    switch (level) {
+        case CTRLogLevel::DEBUG:
+            std::cout << "DEBUG:" << msg << std::endl;
+            break;
+        case CTRLogLevel::INFO:
+            std::cout << "INFO:" << msg << std::endl;
+            break;
+        case CTRLogLevel::WARN:
+            std::cout << "WARN:" << msg << std::endl;
+            break;
+        case CTRLogLevel::ERROR:
+            std::cout << "ERROR:" << msg << std::endl;
+            break;
+        default:
+            break;
+    }
+}
+
+#endif // CTR_COMMON_H
diff --git a/src/AccCTR/tests/ut/src/emb_cache_test.cpp b/src/AccCTR/tests/ut/src/emb_cache_test.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..dda5423c1439882214a55a636dc6f0ae2371c240
--- /dev/null
+++ b/src/AccCTR/tests/ut/src/emb_cache_test.cpp
@@ -0,0 +1,1999 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+ ==============================================================================*/
+
+#include <algorithm>
+#include <cmath>
+#include <cstdlib>
+#include <random>
+#include <set>
+#include <thread>
+
+#include "common/util/error_code.h"
+#include "emb_cache_test.h"
+#include "common.h"
+
+using namespace std;
+using namespace ock::ctr;
+
+FactoryPtr factory;
+EmbCacheManagerPtr embCache = nullptr;
+
+std::vector<uint64_t> GenKeys(uint64_t n, uint32_t seed = 0, uint64_t min = 0, uint64_t max = UINT64_MAX)
+{
+    std::mt19937 generator(seed);
+    std::uniform_int_distribution<uint64_t> distribution(min, max);
+    std::vector<uint64_t> data(n);
+    for (uint64_t &x : data) {
+        x = distribution(generator);
+    }
+    sort(data.begin(), data.end());
+    data.erase(unique(data.begin(), data.end()), data.end());
+    return data;
+}
+
+std::vector<uint64_t> GenUniqueKeys(uint64_t n)
+{
+    std::vector<uint64_t> data(n);
+    for (uint64_t i = 0; i < n; i++) {
+        data[i] = i;
+    }
+    return data;
+}
+
+EmbCacheManagerPtr EmbCacheTest::SimpleCreateTable(std::string tableName, uint32_t hostVocabSize,
+    uint32_t embeddingSize, uint32_t extEmbeddingSize, uint32_t devVocabSize, pair<float, float> normalPara,
+    float constPara)
+{
+    factory->CreateEmbCacheManager(embCache);
+    EmbCache::EmbCacheInfo embCacheInfo(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize);
+
+    EmbCache::NormalInitializerInfo normalInitializerInfo(normalPara.first, normalPara.second, 0, 1.0);
+    std::string normalInitializeName = "random_normal_initializer";
+    EmbCache::InitializerInfo normalInitializeInfo(normalInitializeName, 0, embeddingSize, normalInitializerInfo);
+
+    EmbCache::ConstantInitializerInfo constantInitializerInfo(constPara, 1.0);
+    std::string constantInitializeName = "constant_initializer";
+
+    std::vector<EmbCache::InitializerInfo> initializeInfos(extEmbeddingSize / embeddingSize);
+    initializeInfos[0] = normalInitializeInfo;
+    for (uint64_t i = 1; i < initializeInfos.size(); i++) {
+        initializeInfos[i] = EmbCache::InitializerInfo(constantInitializeName, embeddingSize * i, embeddingSize,
+            constantInitializerInfo);
+    }
+    int ret = embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize, 1);
+    if (ret != H_OK) {
+        string msg = "CreateCacheForTable Failed. ret: " + std::to_string(ret);
+        CTRLog(CTRLogLevel::ERROR, msg.c_str());
+        return nullptr;
+    }
+    return embCache;
+}
+
+EmbCacheManagerPtr EmbCacheTest::ConstZeroCreateTable(std::string tableName, uint32_t hostVocabSize,
+    uint32_t embeddingSize, uint32_t extEmbeddingSize, uint32_t devVocabSize, uint64_t prefillBufferSize,
+    uint8_t prefillThreadNum)
+{
+    factory->CreateEmbCacheManager(embCache);
+    EmbCache::EmbCacheInfo embCacheInfo(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize);
+
+    EmbCache::ConstantInitializerInfo constantInitializerInfo(0.0, 1.0);
+    std::string constantInitializeName = "constant_initializer";
+
+    std::vector<EmbCache::InitializerInfo> initializeInfos = { EmbCache::InitializerInfo(constantInitializeName, 0,
+        extEmbeddingSize, constantInitializerInfo) };
+    int ret = embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, prefillBufferSize, prefillThreadNum);
+    if (ret != H_OK) {
+        string msg = "CreateCacheForTable Failed. ret: " + std::to_string(ret);
+        CTRLog(CTRLogLevel::ERROR, msg.c_str());
+        return nullptr;
+    }
+    return embCache;
+}
+
+void EmbCacheTest::SetUpTestCase()
+{
+    Factory::Create(factory);
+    factory->SetExternalLogFuncInner(CTRLog);
+}
+
+void EmbCacheTest::TearDownTestCase() {}
+
+void EmbCacheTest::SetUp() {}
+
+void EmbCacheTest::TearDown()
+{
+    if (embCache != nullptr) {
+        embCache->Destroy();
+        embCache = nullptr;
+    }
+}
+
+TEST_F(EmbCacheTest, ConstantInitializerInfo)
+{
+    CTRLog(CTRLogLevel::INFO, "===========ConstantInitializerInfo start=============");
+
+    // Correctly constructs a ConstantInitializerInfo struct; no log output expected
+    EmbCache::ConstantInitializerInfo constantInitializerInfo(0.233, 1.0);
+    CTRLog(CTRLogLevel::INFO, "===========ConstantInitializerInfo end=============");
+}
+
+TEST_F(EmbCacheTest, NormalInitializerInfo)
+{
+    CTRLog(CTRLogLevel::INFO, "===========NormalInitializerInfo start=============");
+    // Correctly constructs a NormalInitializerInfo struct; no log output expected
+    EmbCache::NormalInitializerInfo normalInitializerInfo(0, 0.05, 0, 1.0);
+    // A negative stddev is mathematically meaningless; passing one is at the caller's own risk
+    EmbCache::NormalInitializerInfo normalInitializerInfo_ne_dev(0, -0.05, 0, 1.0);
+    CTRLog(CTRLogLevel::INFO, "===========NormalInitializerInfo end=============");
+}
+
+TEST_F(EmbCacheTest, InitializerInfo)
+{
+    CTRLog(CTRLogLevel::INFO, "===========InitializerInfo start=============");
+    uint32_t embeddingSize = 13;
+
+    EmbCache::NormalInitializerInfo normalInitializerInfo(0, 0.05, 0, 1.0);
+    EmbCache::ConstantInitializerInfo constantInitializerInfo(0.233, 1.0);
+
+    // A std::string other than "constant_initializer" logs "Invalid Initializer Type."
+    std::string not_a_initializer_name = "not_a_initializer_name";
+    EmbCache::InitializerInfo constantInitializeInfo =
+        EmbCache::InitializerInfo(not_a_initializer_name, embeddingSize, embeddingSize, constantInitializerInfo);
+
+    // A std::string other than "constant_initializer" logs "Invalid Initializer Type."
+    not_a_initializer_name = "";
+    constantInitializeInfo =
+        EmbCache::InitializerInfo(not_a_initializer_name, embeddingSize, embeddingSize, constantInitializerInfo);
+
+    // Correctly constructs an InitializeInfo struct; no log output expected
+    std::string constantInitializeName = "constant_initializer";
+    constantInitializeInfo =
+        EmbCache::InitializerInfo(constantInitializeName, embeddingSize, embeddingSize + 1, constantInitializerInfo);
+
+    // A std::string other than "random_normal_initializer" or "truncated_normal_initializer" logs "Invalid Initializer
+    // Type."
+    not_a_initializer_name = "not_a_initializer_name";
+    EmbCache::InitializerInfo normalInitializeInfo =
+        EmbCache::InitializerInfo(not_a_initializer_name, embeddingSize, embeddingSize, normalInitializerInfo);
+
+    // A std::string other than "random_normal_initializer" or "truncated_normal_initializer" logs "Invalid Initializer
+    // Type."
+ not_a_initializer_name = ""; + normalInitializeInfo = + EmbCache::InitializerInfo(not_a_initializer_name, embeddingSize, embeddingSize, normalInitializerInfo); + + // 正确初始化InitializeInfo结构体,无日志信息反馈 + std::string normalInitializeName = "random_normal_initializer"; + normalInitializeInfo = EmbCache::InitializerInfo(normalInitializeName, 0, embeddingSize, normalInitializerInfo); + + // 正确初始化InitializeInfo结构体,无日志信息反馈 + std::string truncatedNormalInitializeName = "truncated_normal_initializer"; + EmbCache::InitializerInfo truncatedNormalInitializeInfo = + EmbCache::InitializerInfo(truncatedNormalInitializeName, 0, embeddingSize, normalInitializerInfo); + + CTRLog(CTRLogLevel::INFO, "===========InitializerInfo end============="); +} + +TEST_F(EmbCacheTest, EmbCacheInfo) +{ + CTRLog(CTRLogLevel::INFO, "===========EmbCacheInfo start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 5; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 2; + // 正确初始化EmbCacheInfo结构体,无日志信息反馈 + EmbCache::EmbCacheInfo embCacheInfo(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + CTRLog(CTRLogLevel::INFO, "===========EmbCacheInfo end============="); +} + +TEST_F(EmbCacheTest, CreateCacheForTable) +{ + factory->CreateEmbCacheManager(embCache); + CTRLog(CTRLogLevel::INFO, "===========CreateCacheForTable start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 5; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 2; + EmbCache::EmbCacheInfo embCacheInfo(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, {}, -1, hostVocabSize), H_INITIALIZER_INVALID); + + EmbCache::NormalInitializerInfo normalInitializerInfo(0, 0.05, 0, 1.0); + std::string normalInitializeName = "random_normal_initializer"; + EmbCache::InitializerInfo normalInitializeInfo(normalInitializeName, 0, embeddingSize, normalInitializerInfo); + + // 空initializer 日志打印出"Initializer is nullptr" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, { {}, {} }, -1, hostVocabSize), H_INITIALIZER_INVALID); + + normalInitializeInfo.initializer = nullptr; + // 空initializer 日志打印出"Initializer is nullptr" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, { normalInitializeInfo }, -1, hostVocabSize), + H_INITIALIZER_INVALID); + + normalInitializeInfo = EmbCache::InitializerInfo(normalInitializeName, 0, embeddingSize, normalInitializerInfo); + EmbCache::ConstantInitializerInfo constantInitializerInfo(0.233, 1.0); + std::string constantInitializeName = "constant_initializer"; + EmbCache::InitializerInfo constantInitializeInfo(constantInitializeName, embeddingSize, embeddingSize + 1, + constantInitializerInfo); + std::vector initializeInfos = { normalInitializeInfo, constantInitializeInfo }; + + // initializerInfos的区间之间有重叠或者遗漏 日志打印出"Initializers got coverage problems" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_INITIALIZER_INVALID); + + constantInitializeInfo = + EmbCache::InitializerInfo(constantInitializeName, embeddingSize + 1, embeddingSize, constantInitializerInfo); + initializeInfos = { normalInitializeInfo, constantInitializeInfo }; + // initializerInfos的区间之间有重叠或者遗漏 日志打印出"Initializers got coverage problems" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_INITIALIZER_INVALID); + + + embCacheInfo.extEmbeddingSize = 
extEmbeddingSize; + std::string not_a_initializer_name = "not_a_initializer_name"; + constantInitializeInfo = + EmbCache::InitializerInfo(not_a_initializer_name, embeddingSize, embeddingSize, constantInitializerInfo); + initializeInfos = { normalInitializeInfo, constantInitializeInfo }; + + // 传入的Initializer的name不符要求 日志打印出"Invalid Initializer Type.\nInitializer is nullptr" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_INITIALIZER_INVALID); + + constantInitializeInfo = + EmbCache::InitializerInfo(constantInitializeName, embeddingSize, embeddingSize, constantInitializerInfo); + initializeInfos = { normalInitializeInfo, constantInitializeInfo }; + + embCacheInfo.extEmbeddingSize++; + + // 传入的embInfo中的传入的extEmbeddingSize并非embeddingSize的整数倍 日志打印出"extEmbeddingSize = embeddingSize + + // optimizerSize, which is divisible by embeddingSize" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), + H_EXT_EMBEDDING_SIZE_INVALID); + + embCacheInfo.maxCacheSize = 100; + // maxCacheSize>vocabSize 日志打印出"vocabSize must be greater than or equal to maxCacheSize" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), + H_HOST_VOCAB_SIZE_TOO_SMALL); + embCacheInfo.maxCacheSize = devVocabSize; + + embCacheInfo.extEmbeddingSize = 0; + // extEmbeddingSize为0 日志打印出"size must be positive" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_SIZE_ZERO); + embCacheInfo.extEmbeddingSize = extEmbeddingSize; + + embCacheInfo.embeddingSize = 0; + // embeddingSize为0 日志打印出"size must be positive" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_SIZE_ZERO); + embCacheInfo.embeddingSize = embeddingSize; + + embCacheInfo.vocabSize = 0; + // vocabSize为0 日志打印出"size must be positive" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_SIZE_ZERO); + embCacheInfo.vocabSize = hostVocabSize; + + embCacheInfo.maxCacheSize = 0; + // maxCacheSize为0 日志打印出"size must be positive" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_SIZE_ZERO); + embCacheInfo.maxCacheSize = devVocabSize; + + embCacheInfo.tableName = ""; + // 传入的tableName空 日志打印出"tableName can not be empty" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_TABLE_NAME_EMPTY); + + embCacheInfo.tableName = + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + 
"0000000001000000000100000000010001"; + // 传入的tableName长度正好为长度上限1024 + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_OK); + + embCacheInfo.tableName = + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100012"; + // 传入的tableName长度为1025超过了长度上限 + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_TABLE_NAME_TOO_LONG); + embCacheInfo.tableName = tableName; + + // 正常创建 日志中不会打印异常信息 + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_OK); + + // 重复创建同名Table 日志打印出"This table has already been created" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), + H_TABLE_CREATE_DUPLICATE); + embCache->Destroy(); + + // Destroy后仍能正常创建 日志中不会打印异常信息 + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_OK); + embCache->Destroy(); + + // prefill单线程 + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, 3, 1), H_OK); + embCache->Destroy(); + + // prefill多线程 + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, 3, 3), H_OK); + embCache->Destroy(); + + // prefill多线程 + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, 3, 0), H_THREAD_NUM_ERROR); + embCache->Destroy(); + + // prefill过多线程 + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, 3, 10000), H_THREAD_NUM_ERROR); + embCache->Destroy(); + + // prefill 正常buffersize + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, 3, 1), H_OK); + embCache->Destroy(); + + // prefill 超大buffersize + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, 10, 1), H_PREFILL_BUFFER_SIZE_INVALID); + embCache->Destroy(); + + // prefill 0buffersize + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, 0, 1), H_PREFILL_BUFFER_SIZE_INVALID); + CTRLog(CTRLogLevel::INFO, "===========CreateCacheForTable end============="); +} + +TEST_F(EmbCacheTest, EMBEDDING_LOOKUP_ADDRS) +{ + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP_ADDRS start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 5; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 2; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::vector lookupKeys; + std::vector addrs; + + lookupKeys = { 0, 1, 2, 3, 4 }; + 
ASSERT_EQ(embCache->EmbeddingLookupAddrs(tableName, lookupKeys, addrs), H_OK); + + // lookupkeys 为空 + lookupKeys = {}; + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tableName, lookupKeys, addrs), H_OK); + + lookupKeys = { 0 }; + ASSERT_EQ(embCache->EmbeddingLookupAddrs("not_a_table", lookupKeys, addrs), H_TABLE_NOT_EXIST); + + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tooLongTableName, lookupKeys, addrs), H_TABLE_NAME_TOO_LONG); + + lookupKeys = { 5 }; + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tableName, lookupKeys, addrs), H_HOST_VOCAB_SIZE_TOO_SMALL); + + lookupKeys = { 5 }; + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tableName, lookupKeys, addrs, 1), H_HOST_VOCAB_SIZE_TOO_SMALL); + + lookupKeys = { 0, 1, 4 }; + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tableName, lookupKeys, addrs), H_OK); + + lookupKeys = { 0, 1, 4 }; + uint32_t threadNum = std::thread::hardware_concurrency(); + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tableName, lookupKeys, addrs, threadNum + 1), H_THREAD_NUM_ERROR); + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tableName, lookupKeys, addrs, threadNum), H_OK); + // 单线程lookup + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tableName, lookupKeys, addrs, 1), H_OK); + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tableName, lookupKeys, addrs, 0), H_THREAD_NUM_ERROR); + embCache->Destroy(); +} + +TEST_F(EmbCacheTest, EMBEDDING_LOOKUP_ADDRS_DATA) +{ + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP_ADDRS_DATA start============="); + factory->CreateEmbCacheManager(embCache); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 3000000; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 39; + uint32_t devVocabSize = 100000; + EmbCache::EmbCacheInfo embCacheInfo(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::string normalInitializeName = "random_normal_initializer"; + std::string constantInitializeName = "constant_initializer"; + EmbCache::NormalInitializerInfo normalInitializerInfo(0, 0.05, 0, 1.0); + EmbCache::ConstantInitializerInfo constantInitializerInfo(0.233, 1.0); + + std::string truncatedNormalInitializeName = "truncated_normal_initializer"; + // 加入所有初始化器的所有分支 + std::vector initializeInfos = { + EmbCache::InitializerInfo(normalInitializeName, 0, embeddingSize, normalInitializerInfo), + EmbCache::InitializerInfo(normalInitializeName, embeddingSize, 0, normalInitializerInfo), + EmbCache::InitializerInfo(constantInitializeName, embeddingSize, embeddingSize, constantInitializerInfo), + EmbCache::InitializerInfo(constantInitializeName, 2 * embeddingSize, 0, constantInitializerInfo), + EmbCache::InitializerInfo(truncatedNormalInitializeName, 2 * embeddingSize, embeddingSize, + normalInitializerInfo), + EmbCache::InitializerInfo(truncatedNormalInitializeName, 3 * embeddingSize, 0, normalInitializerInfo), + }; + // 正确创建 + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos), H_OK); + std::vector lookupKeys; + std::vector addrs; + lookupKeys = GenKeys(hostVocabSize, 123321); + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tableName, lookupKeys, addrs), H_OK); + + long double sum = 0.0; + long double cnt = 0.0; + long double accum = 0.0; + for (uint32_t i = 0; i < lookupKeys.size(); i++) { + // normalInitializer 生成数据 + for (uint32_t j = 0; j < embeddingSize; j++) { + sum += addrs[i][j]; + cnt++; + } + + // constantInitializer 生成数据 + for (uint32_t j = embeddingSize; j < 2 * embeddingSize; j++) { + ASSERT_LE(std::abs(addrs[i][j] - 0.233), 1e-6f); + } + // truncatedNormalInitializer 生成数据 + for (uint32_t j = 2 * 
embeddingSize; j < 3 * embeddingSize; j++) { + // 在[-2*stddev, 2*stddev]范围中 + ASSERT_LE(std::abs(addrs[i][j]), 0.1f + 1e-6f); + } + } + + long double mean = sum / cnt; + for (uint32_t i = 0; i < lookupKeys.size(); ++i) { + for (uint32_t j = 0; j < embeddingSize; j++) { + accum += (addrs[i][j] - mean) * (addrs[i][j] - mean); + } + } + long double stdev = sqrt(accum / cnt); + ASSERT_LE(std::abs(mean), 5e-6f); + ASSERT_LE(std::abs(stdev - 0.05), 5e-6f); + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP_ADDRS_DATA end============="); +} + +TEST_F(EmbCacheTest, EMBEDDING_LOOKUP_300W) +{ + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP_300W start============="); + factory->CreateEmbCacheManager(embCache); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 3000000; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 39; + uint32_t devVocabSize = 100000; + EmbCache::EmbCacheInfo embCacheInfo(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::string normalInitializeName = "random_normal_initializer"; + std::string constantInitializeName = "constant_initializer"; + EmbCache::NormalInitializerInfo normalInitializerInfo(0, 0.05, 0, 1.0); + EmbCache::ConstantInitializerInfo constantInitializerInfo(0.233, 1.0); + + std::string truncatedNormalInitializeName = "truncated_normal_initializer"; + // 加入所有初始化器的所有分支 + std::vector initializeInfos = { + EmbCache::InitializerInfo(normalInitializeName, 0, embeddingSize, normalInitializerInfo), + EmbCache::InitializerInfo(normalInitializeName, embeddingSize, 0, normalInitializerInfo), + EmbCache::InitializerInfo(constantInitializeName, embeddingSize, embeddingSize, constantInitializerInfo), + EmbCache::InitializerInfo(constantInitializeName, 2 * embeddingSize, 0, constantInitializerInfo), + EmbCache::InitializerInfo(truncatedNormalInitializeName, 2 * embeddingSize, embeddingSize, + normalInitializerInfo), + EmbCache::InitializerInfo(truncatedNormalInitializeName, 3 * embeddingSize, 0, normalInitializerInfo), + }; + // 正确创建 + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos), H_OK); + std::vector lookupKeys; + float *addr; + lookupKeys = GenKeys(hostVocabSize, 123321); + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + + long double sum = 0.0; + long double cnt = 0.0; + long double accum = 0.0; + for (uint32_t i = 0; i < lookupKeys.size(); i++) { + // normalInitializer 生成数据 + for (uint32_t j = 0; j < embeddingSize; j++) { + sum += addr[i * extEmbeddingSize + j]; + cnt++; + } + + // constantInitializer 生成数据 + for (uint32_t j = embeddingSize; j < 2 * embeddingSize; j++) { + ASSERT_LE(std::abs(addr[i * extEmbeddingSize + j] - 0.233), 1e-6f); + } + // truncatedNormalInitializer 生成数据 + for (uint32_t j = 2 * embeddingSize; j < 3 * embeddingSize; j++) { + // 在[-2*stddev, 2*stddev]范围中 + ASSERT_LE(std::abs(addr[i * extEmbeddingSize + j]), 0.1f + 1e-6f); + } + } + + long double mean = sum / cnt; + for (uint32_t i = 0; i < lookupKeys.size(); ++i) { + for (uint32_t j = 0; j < embeddingSize; j++) { + accum += (addr[i * extEmbeddingSize + j] - mean) * (addr[i * extEmbeddingSize + j] - mean); + } + } + long double stdev = sqrt(accum / cnt); + ASSERT_LE(std::abs(mean), 5e-6f); + ASSERT_LE(std::abs(stdev - 0.05), 5e-6f); + free(addr); + CTRLog(CTRLogLevel::INFO, "===========GenerateData end============="); +} + +TEST_F(EmbCacheTest, EMBEDDING_LOOKUP_AND_REMOVE) +{ + 
CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP_AND_REMOVE start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 5; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 2; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::vector lookupKeys; + float *addr; + + lookupKeys = { 0, 1, 2, 3, 4 }; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tableName, lookupKeys, addr), H_OK); + free(addr); + + // lookupkeys 为空 + lookupKeys = {}; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tableName, lookupKeys, addr), H_OK); + free(addr); + + lookupKeys = { 0 }; + addr = nullptr; + ASSERT_EQ(embCache->EmbeddingLookupAndRemove("not_a_table", lookupKeys, addr), H_TABLE_NOT_EXIST); + + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tooLongTableName, lookupKeys, addr), H_TABLE_NAME_TOO_LONG); + + lookupKeys = { 0 }; + addr = nullptr; + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tableName, lookupKeys, addr), H_ADDRESS_NULL); + + lookupKeys = { 0, 1, 4 }; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + uint32_t threadNum = std::thread::hardware_concurrency(); + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tableName, lookupKeys, addr, threadNum + 1), H_THREAD_NUM_ERROR); + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tableName, lookupKeys, addr, threadNum), H_OK); + // 单线程lookup + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tableName, lookupKeys, addr, 1), H_OK); + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tableName, lookupKeys, addr, 0), H_THREAD_NUM_ERROR); + free(addr); + embCache->Destroy(); + + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP_AND_REMOVE end============="); +} + +TEST_F(EmbCacheTest, EMBEDDING_LOOKUP_AND_REMOVE_2) +{ + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP_AND_REMOVE_2 start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 200; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 2; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::vector lookupKeys; + float *addr; + + for (int i = 0; i < 100; i++) { + for (int j = 0; j < 2; j++) { + lookupKeys.emplace_back(i); + } + } + + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tableName, lookupKeys, addr, 1), H_OK); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tableName, lookupKeys, addr), H_OK); + free(addr); + embCache->Destroy(); + + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP_AND_REMOVE_2 end============="); +} + +TEST_F(EmbCacheTest, EMBEDDING_LOOKUP) +{ + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 5; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 2; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::vector lookupKeys; + float *addr; + + lookupKeys = { 0, 1, 2, 3, 4 }; + addr = (float *)malloc(lookupKeys.size() * 
extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + free(addr); + + // lookupkeys 为空 + lookupKeys = {}; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + free(addr); + + lookupKeys = { 0 }; + addr = nullptr; + ASSERT_EQ(embCache->EmbeddingLookup("not_a_table", lookupKeys, addr), H_TABLE_NOT_EXIST); + + ASSERT_EQ(embCache->EmbeddingLookup(tooLongTableName, lookupKeys, addr), H_TABLE_NAME_TOO_LONG); + + lookupKeys = { 5 }; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_HOST_VOCAB_SIZE_TOO_SMALL); + free(addr); + + lookupKeys = { 0 }; + addr = nullptr; + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_ADDRESS_NULL); + + lookupKeys = { 0, 1, 4 }; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + free(addr); + + lookupKeys = { 0, 1, 4 }; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + uint32_t threadNum = std::thread::hardware_concurrency(); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr, threadNum + 1), H_THREAD_NUM_ERROR); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr, threadNum), H_OK); + // 单线程lookup + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr, 1), H_OK); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr, 0), H_THREAD_NUM_ERROR); + free(addr); + embCache->Destroy(); + + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP end============="); +} + +TEST_F(EmbCacheTest, EMBEDDING_LOOKUP_AND_REMOVE_300W) +{ + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP_AND_REMOVE_300W start============="); + std::string tableName = "test_table"; + std::vector lookupKeys; + float *newEmb; + + // 300w个key + uint32_t hostVocabSize = 3000000; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 100000; + embCache = ConstZeroCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + lookupKeys = GenUniqueKeys(hostVocabSize); + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size(); i++) { + for (uint32_t j = 0; j < extEmbeddingSize; j++) { + newEmb[i * extEmbeddingSize + j] = i + 0.01f * j; // 生成特殊数据 + } + } + CTRLog(CTRLogLevel::INFO, "gen done"); + // 把特殊数据放到表中 + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_OK); + free(newEmb); + CTRLog(CTRLogLevel::INFO, "EmbeddingUpdate done"); + + float *addr; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + // 查询特殊数据 + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + CTRLog(CTRLogLevel::INFO, "EmbeddingLookup done"); + for (uint32_t i = 0; i < lookupKeys.size(); i++) { + for (uint32_t j = 0; j < extEmbeddingSize; j++) { + // 验证表中数据正确性 + ASSERT_LE(std::abs(addr[i * extEmbeddingSize + j] - (i + 0.01f * j)), 1e-6f); + } + } + free(addr); + addr = nullptr; + + // Remove之后再Lookup,观察这些embedding是不是被正确remove + // 首先确认EmbeddingLookupAndRemove不会报错 + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tableName, lookupKeys, addr, 4), H_OK); + for (uint32_t i = 0; i < lookupKeys.size(); 
i++) { + for (uint32_t j = 0; j < extEmbeddingSize; j++) { + // 验证表中数据正确性 + ASSERT_LE(std::abs(addr[i * extEmbeddingSize + j] - (i + 0.01f * j)), 1e-6f); + } + } + free(addr); + addr = nullptr; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + // 然后再lookup,并确保lookup不会报错 + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + // 因为用const zero初始化, EmbeddingLookupAndRemove之后再lookup,结果应该全是0 + for (uint32_t i = 0; i < lookupKeys.size(); i++) { + for (uint32_t j = 0; j < extEmbeddingSize; j++) { + // 验证表中数据正确性 + ASSERT_LE(std::abs(addr[i * extEmbeddingSize + j] - 0), 1e-6f); + } + } + free(addr); + + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP_AND_REMOVE_300W end============="); +} + +TEST_F(EmbCacheTest, EMBEDDING_UPDATE_300W) +{ + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_UPDATE_300W start============="); + std::string tableName = "test_table"; + std::vector lookupKeys; + float *newEmb; + + // 300w个key + uint32_t hostVocabSize = 3000000; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 100000; + embCache = ConstZeroCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize, 50000, 6); + lookupKeys = GenKeys(hostVocabSize, 123321); + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size(); i++) { + for (uint32_t j = 0; j < extEmbeddingSize; j++) { + newEmb[i * extEmbeddingSize + j] = i + 0.01f * j; // 生成特殊数据 + } + } + CTRLog(CTRLogLevel::INFO, "gen done"); + // 把特殊数据放到表中 + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_OK); + free(newEmb); + CTRLog(CTRLogLevel::INFO, "EmbeddingUpdate done"); + + float *addr; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + // 查询特殊数据 + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + CTRLog(CTRLogLevel::INFO, "EmbeddingLookup done"); + for (uint32_t i = 0; i < lookupKeys.size(); i++) { + for (uint32_t j = 0; j < extEmbeddingSize; j++) { + // 验证表中数据正确性 + ASSERT_LE(std::abs(addr[i * extEmbeddingSize + j] - (i + 0.01f * j)), 1e-6f); + } + } + // Remove之后再Lookup,观察这些embedding是不是被正确remove + // 首先确认remove不会报错 + ASSERT_EQ(embCache->RemoveEmbsByKeys(tableName, lookupKeys), H_OK); + // 然后再lookup,并确保lookup不会报错 + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + // 因为用const zero初始化, 删除之后再lookup,结果应该全是0 + for (uint32_t i = 0; i < lookupKeys.size(); i++) { + for (uint32_t j = 0; j < extEmbeddingSize; j++) { + // 验证表中数据正确性 + ASSERT_LE(std::abs(addr[i * extEmbeddingSize + j] - 0), 1e-6f); + } + } + free(addr); + + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_UPDATE_300W end============="); +} + +TEST_F(EmbCacheTest, EMBEDDING_UPDATE) +{ + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_UPDATE start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 5; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 2; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::vector lookupKeys; + float *newEmb; + + lookupKeys = { 0, 1, 2, 3, 4 }; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + + // 更新存在的table,应当正常更新 + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_OK); + free(newEmb); + + lookupKeys = 
{ 0 }; + newEmb = nullptr; + // 更新不存在的table + ASSERT_EQ(embCache->EmbeddingUpdate("not_a_table", lookupKeys, newEmb), H_TABLE_NOT_EXIST); + + // 表名超过上限 + ASSERT_EQ(embCache->EmbeddingUpdate(tooLongTableName, lookupKeys, newEmb), H_TABLE_NAME_TOO_LONG); + + lookupKeys = { 5 }; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + + // 当前embLocalTable中存储的key已达到hostVocabSize上限,并继续添加新key + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_HOST_VOCAB_SIZE_TOO_SMALL); + free(newEmb); + + lookupKeys = { 0 }; + newEmb = nullptr; + // 传入embAddr为空指针 + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_ADDRESS_NULL); + + // 更新存在于table的keys, 传入embAddr不为空指针 + lookupKeys = { 0, 1, 4 }; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_OK); + free(newEmb); + + // 线程数未超过核数 + lookupKeys = { 0, 1, 4 }; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb, 4), H_OK); + free(newEmb); + + // 线程数等于核数 + uint32_t processCoreNum = std::thread::hardware_concurrency(); + lookupKeys = { 0, 1, 4 }; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb, processCoreNum), H_OK); + free(newEmb); + + // 线程数大于核数 + processCoreNum = std::thread::hardware_concurrency(); + lookupKeys = { 0, 1, 4 }; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb, processCoreNum + 1), H_THREAD_NUM_ERROR); + free(newEmb); + + // 线程数为0 + processCoreNum = std::thread::hardware_concurrency(); + lookupKeys = { 0, 1, 4 }; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb, 0), H_THREAD_NUM_ERROR); + free(newEmb); + + // 线程数为1 + lookupKeys = { 0, 1, 4 }; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb, 1), H_OK); + free(newEmb); + + // lookupkeys为空 + lookupKeys = {}; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb, 1), H_OK); + free(newEmb); + + TearDown(); + + // 更新不存在于table的key,且当前embLocalTable中存储的key未达到hostVocabSize上限,继续添加新key + tableName = "test_table_one"; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + lookupKeys = { 0, 1 }; + newEmb = 
(float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb); + free(newEmb); + lookupKeys = { 2, 3 }; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_OK); + free(newEmb); + + + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_UPDATE end============="); +} + +TEST_F(EmbCacheTest, GetSwapPairsAndKey2Offset) +{ + CTRLog(CTRLogLevel::INFO, "===========GetSwapPairsAndKey2Offset start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 100; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 10; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::vector insertKeys; + std::pair, std::vector> swapInKoPair, swapOutKoPair; + + // 使用不存在的table + insertKeys = { 0, 1, 2, 3, 4 }; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset("not_a_table", insertKeys, swapInKoPair, swapOutKoPair), + H_TABLE_NOT_EXIST); + + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tooLongTableName, insertKeys, swapInKoPair, swapOutKoPair), + H_TABLE_NAME_TOO_LONG); + + // 正常查找不存在的keys + insertKeys = { 0, 1, 2, 3, 4 }; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair, swapOutKoPair), H_OK); + bool ret1 = true; + for (uint64_t i = 0; i < swapInKoPair.first.size(); i++) { + if (swapInKoPair.first[i] != i) { + string msg = "the " + std::to_string(i) + "th has key " + std::to_string(swapInKoPair.first[i]) + + ", but expect " + std::to_string(i); + CTRLog(CTRLogLevel::INFO, msg.c_str()); + ret1 = false; + } + } + ASSERT_EQ(ret1, true); + + // 正常查找存在的keys + std::pair, std::vector> swapInKoPair2, swapOutKoPair2; + insertKeys = { 1, 2, 3 }; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair2, swapOutKoPair2), H_OK); + uint64_t uint_zero = 0; + ASSERT_EQ(swapInKoPair2.first.size(), uint_zero); + + std::pair, std::vector> swapInKoPair3, swapOutKoPair3; + insertKeys = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + // 使用非空的koPair + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair, swapOutKoPair3), + H_ARG_NOT_EMPTY); + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair3, swapInKoPair), H_ARG_NOT_EMPTY); + // 存入keys正好达到maxCacheSize上限值 + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair3, swapOutKoPair3), H_OK); + + // 存入keys正好越过到maxCacheSize上限值 + std::pair, std::vector> swapInKoPair4, swapOutKoPair4; + insertKeys = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair4, swapOutKoPair4), + H_MAX_CACHESIZE_TOO_SMALL); + + embCache->Destroy(); + // 单次存入keys超过maxCacheSize上限值 + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::pair, std::vector> swapInKoPair5, swapOutKoPair5; + insertKeys = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 }; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair5, swapOutKoPair5), + H_MAX_CACHESIZE_TOO_SMALL); + + embCache->Destroy(); + // 单次存入keys正好达到上限值后,再次查找已存在的keys + embCache = SimpleCreateTable(tableName, 
hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::pair, std::vector> swapInKoPair6, swapOutKoPair6; + insertKeys = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair6, swapOutKoPair6), H_OK); + + embCache->Destroy(); + // 连续两次存入的keys未超过上限,第三次传入keys达到上限 + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::pair, std::vector> swapInKoPair7, swapOutKoPair7; + insertKeys = { 0, 1, 2, 3, 4 }; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair7, swapOutKoPair7), H_OK); + + std::pair, std::vector> swapInKoPair8, swapOutKoPair8; + insertKeys = { 5, 6, 7, 8, 9 }; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair8, swapOutKoPair8), H_OK); + + std::pair, std::vector> swapInKoPair9, swapOutKoPair9; + insertKeys = { 10, 11, 12, 13, 14 }; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair9, swapOutKoPair9), H_OK); + + embCache->Destroy(); + // 查询INVALID_KEY + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::pair, std::vector> swapInKoPair10, swapOutKoPair10; + uint64_t neg_one = -1; + insertKeys = { neg_one, neg_one, neg_one, neg_one, neg_one }; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair10, swapOutKoPair10), H_OK); + ASSERT_EQ(swapInKoPair10.first.empty(), true); + ASSERT_EQ(swapInKoPair10.second.empty(), true); + ASSERT_EQ(swapOutKoPair10.first.empty(), true); + ASSERT_EQ(swapOutKoPair10.second.empty(), true); + + // 查找空keys + std::pair, std::vector> swapInKoPair11, swapOutKoPair11; + insertKeys = {}; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair11, swapOutKoPair11), H_OK); + ASSERT_EQ(swapInKoPair11.first.empty(), true); + ASSERT_EQ(swapInKoPair11.second.empty(), true); + ASSERT_EQ(swapOutKoPair11.first.empty(), true); + ASSERT_EQ(swapOutKoPair11.second.empty(), true); + CTRLog(CTRLogLevel::INFO, "===========GetSwapPairsAndKey2Offset end============="); +} + + +bool checkKeys(std::set &keySet, std::vector> &historyKeyVec, + const std::vector &keys, const std::vector &swapInKeys, + const std::vector &swapOutKeys, uint32_t maxCacheSize) +{ + std::set newKeys; + for (auto key : keys) { + if (keySet.find(key) == keySet.end()) { + newKeys.insert(key); + } + keySet.insert(key); + } + for (auto key : swapInKeys) { + if (newKeys.find(key) == newKeys.end()) { + CTRLog(CTRLogLevel::ERROR, "swapIn key error1"); + return false; + } + } + if (swapInKeys.size() != newKeys.size()) { + CTRLog(CTRLogLevel::ERROR, "swapIn key error2"); + return false; + } + historyKeyVec.insert(historyKeyVec.begin(), { keys.begin(), keys.end() }); + if (historyKeyVec.size() > 2) { + historyKeyVec.pop_back(); + } + for (auto key : swapOutKeys) { + if (historyKeyVec[0].find(key) != historyKeyVec[0].end() || + historyKeyVec[1].find(key) != historyKeyVec[1].end()) { + CTRLog(CTRLogLevel::ERROR, "swapOut key error1"); + return false; + } + } + for (auto key : swapOutKeys) { + if (keySet.find(key) == keySet.end()) { + CTRLog(CTRLogLevel::ERROR, "swapOut key error2"); + return false; + } + } + for (auto key : swapOutKeys) { + keySet.erase(key); + } + if (keySet.size() > maxCacheSize) { + CTRLog(CTRLogLevel::ERROR, "total key size error"); + return false; + } + return true; +} + +bool checkOffsets(std::set &offsetSet, const std::vector &swapInOffsets, + const 
std::vector &swapOutOffset) +{ + for (auto offset : swapOutOffset) { + if (offsetSet.find(offset) == offsetSet.end()) { + CTRLog(CTRLogLevel::ERROR, "swapOut offset error1"); + return false; + } + } + + for (auto offset : swapOutOffset) { + offsetSet.erase(offset); + } + + for (auto offset : swapInOffsets) { + if (offsetSet.find(offset) != offsetSet.end()) { + CTRLog(CTRLogLevel::ERROR, "swapIn offset error"); + return false; + } + offsetSet.insert(offset); + } + + return true; +} + + +TEST_F(EmbCacheTest, DEVICE_COMBINE_TEST) +{ + CTRLog(CTRLogLevel::INFO, "===========DEVICE_COMBINE_TEST start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 4000000; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 30000; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::set keySet; + std::set offsetSet; + std::vector> historyKeyVec; + std::vector> historyOffsetVec; + std::vector lookupKeys; + std::vector check_keys; + for (uint32_t i = 0; i < 50; i++) { + lookupKeys = GenKeys(10000, 123 + i, 0, 100000); + check_keys = lookupKeys; + std::pair, std::vector> koPair1; + std::pair, std::vector> koPair2; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, lookupKeys, koPair1, koPair2), H_OK); + bool retKey1 = checkKeys(keySet, historyKeyVec, check_keys, koPair1.first, koPair2.first, devVocabSize); + bool retOffset1 = checkOffsets(offsetSet, koPair1.second, koPair2.second); + ASSERT_EQ(retKey1, true); + ASSERT_EQ(retOffset1, true); + } + + CTRLog(CTRLogLevel::INFO, "===========DEVICE_COMBINE_TEST end============="); +} + +TEST_F(EmbCacheTest, REMOVE_KEYS) +{ + CTRLog(CTRLogLevel::INFO, "===========REMOVE_KEYS start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 100; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 10; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::vector lookupKeys; + std::vector removeKeys; + float *addr; + float *newEmb; + + for (uint32_t i = 0; i < hostVocabSize - 1; i++) { + lookupKeys.emplace_back(i); + for (uint32_t j = 0; j < hostVocabSize - 1; j++) { + removeKeys.emplace_back(i + j); + } + } + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + free(addr); + + // 表存在 + ASSERT_EQ(embCache->RemoveEmbsByKeys(tableName, lookupKeys), H_OK); + + // 表不存在 + ASSERT_EQ(embCache->RemoveEmbsByKeys("not_a_table", lookupKeys), H_TABLE_NOT_EXIST); + + // 表名超过上限 + ASSERT_EQ(embCache->RemoveEmbsByKeys(tooLongTableName, lookupKeys), H_TABLE_NAME_TOO_LONG); + + // remove INVALID_KEY + uint64_t neg_one = -1; + lookupKeys = { neg_one, neg_one, neg_one, neg_one, neg_one }; + ASSERT_EQ(embCache->RemoveEmbsByKeys(tableName, lookupKeys), H_OK); + + // 判断embLocalTable是否remove掉记录信息 + lookupKeys = { 0, 1, 4 }; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + free(addr); + + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 999.99f; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_OK); + free(newEmb); + + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + 
ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + bool ret1 = true; + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + if (fabs(addr[i] - 999.99f) > 0.0000001) { + ret1 = false; + } + } + free(addr); + ASSERT_EQ(ret1, true); + + ASSERT_EQ(embCache->RemoveEmbsByKeys(tableName, lookupKeys), H_OK); + + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + bool ret2 = true; + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + if (fabs(addr[i] - 999.99f) <= 0.0000001) { + ret2 = false; + } + } + free(addr); + ASSERT_EQ(ret2, true); + + // 判断offsetMapper是否remove掉记录信息 + lookupKeys = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + std::pair, std::vector> swapInKoPair, swapOutKoPair; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, lookupKeys, swapInKoPair, swapOutKoPair), H_OK); + removeKeys = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + ASSERT_EQ(embCache->RemoveEmbsByKeys(tableName, removeKeys), H_OK); + std::vector> koVec; + ASSERT_EQ(embCache->ExportDeviceKeyOffsetPairs(tableName, koVec), H_OK); + bool ret3 = true; + for (uint32_t i = 0; i < koVec.size(); i++) { + if (std::find(removeKeys.begin(), removeKeys.end(), koVec[i].first) != removeKeys.end()) { + ret3 = false; + } + } + ASSERT_EQ(ret3, true); + // 判断删除后,还能再添加 + lookupKeys = { 9, 10, 11, 12, 13 }; + std::vector oldKeys = lookupKeys; + std::pair, std::vector> swapInKoPair2, swapOutKoPair2; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, lookupKeys, swapInKoPair2, swapOutKoPair2), H_OK); + bool ret4 = true; + for (uint32_t i = 0; i < 5; i++) { + if (oldKeys[i] != swapInKoPair2.first[i]) { + ret4 = false; + } + } + bool ret5 = true; + for (uint32_t i = 0; i < 5; i++) { + if (lookupKeys[i] != swapInKoPair2.second[i]) { + ret5 = false; + } + } + ASSERT_EQ(ret4, true); + ASSERT_EQ(ret5, true); + ASSERT_EQ(swapInKoPair2.first.size(), 5ull); + ASSERT_EQ(swapInKoPair2.second.size(), 5ull); + ASSERT_EQ(swapOutKoPair2.first.empty(), true); + ASSERT_EQ(swapOutKoPair2.second.empty(), true); + + removeKeys = { 9, 10, 11, 3 }; + ASSERT_EQ(embCache->RemoveEmbsByKeys(tableName, removeKeys), H_OK); + std::vector> koVec2; + ASSERT_EQ(embCache->ExportDeviceKeyOffsetPairs(tableName, koVec2), H_OK); + bool ret6 = true; + for (uint32_t i = 0; i < koVec2.size(); i++) { + if (std::find(removeKeys.begin(), removeKeys.end(), koVec2[i].first) != removeKeys.end()) { + ret6 = false; + } + } + ASSERT_EQ(ret6, true); + + // 判断删除后,还能再添加 + lookupKeys = { 0, 1, 2, 3, 4, 5, 6, 7 }; + std::vector oldKeys2 = lookupKeys; + std::pair, std::vector> swapInKoPair3, swapOutKoPair3; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, lookupKeys, swapInKoPair3, swapOutKoPair3), H_OK); + bool ret7 = true; + for (uint32_t i = 0; i < 8; i++) { + if (oldKeys2[i] != swapInKoPair3.first[i]) { + ret7 = false; + } + } + bool ret8 = true; + for (uint32_t i = 0; i < 8; i++) { + if (lookupKeys[i] != swapInKoPair3.second[i]) { + ret8 = false; + } + } + ASSERT_EQ(ret7, true); + ASSERT_EQ(ret8, true); + ASSERT_EQ(swapInKoPair3.first.size(), 8ull); + ASSERT_EQ(swapInKoPair3.second.size(), 8ull); + ASSERT_EQ(swapOutKoPair3.first.empty(), true); + ASSERT_EQ(swapOutKoPair3.second.empty(), true); + + lookupKeys = { 15 }; + std::pair, std::vector> swapInKoPair4, swapOutKoPair4; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, lookupKeys, swapInKoPair4, swapOutKoPair4), + H_OK); + + CTRLog(CTRLogLevel::INFO, 
"===========REMOVE_KEYS end============="); +} + +TEST_F(EmbCacheTest, ExportDeviceKeyOffsetPairs) +{ + CTRLog(CTRLogLevel::INFO, "===========ExportDeviceKeyOffsetPairs start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 10; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 8; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + + // 使用不存在的table名字 + std::vector> koVec; + ASSERT_EQ(embCache->ExportDeviceKeyOffsetPairs("not_a_table", koVec), H_TABLE_NOT_EXIST); + + ASSERT_EQ(embCache->ExportDeviceKeyOffsetPairs(tooLongTableName, koVec), H_TABLE_NAME_TOO_LONG); + + // 正常export出koPair + std::vector lookupKeys; + std::vector checkKeys; + lookupKeys = { 6, 0, 8, 1, 3, 4 }; + checkKeys = lookupKeys; + std::pair, std::vector> swapInKoPair, swapOutKoPair; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, lookupKeys, swapInKoPair, swapOutKoPair), H_OK); + std::vector> koVec2; + ASSERT_EQ(embCache->ExportDeviceKeyOffsetPairs(tableName, koVec2), H_OK); + ASSERT_EQ(koVec2.size(), lookupKeys.size()); + bool ret1 = true; + for (uint32_t i = 0; i < koVec2.size(); i++) { + if (koVec2[i].first != checkKeys[i] || koVec2[i].second != lookupKeys[i]) { + ret1 = false; + } + } + ASSERT_EQ(ret1, true); + + CTRLog(CTRLogLevel::INFO, "===========ExportDeviceKeyOffsetPairs end============="); +} + +TEST_F(EmbCacheTest, GetEmbTableNames) +{ + CTRLog(CTRLogLevel::INFO, "===========GetEmbTableNames start============="); + factory->CreateEmbCacheManager(embCache); + uint32_t hostVocabSize = 10; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 8; + std::vector tableNameVec; + tableNameVec.emplace_back("table1"); + tableNameVec.emplace_back("table2"); + tableNameVec.emplace_back("table3"); + for (const std::string tableName : tableNameVec) { + EmbCache::EmbCacheInfo embCacheInfo(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + + EmbCache::NormalInitializerInfo normalInitializerInfo(0, 0.5, 0, 1.0); + std::string normalInitializeName = "random_normal_initializer"; + EmbCache::InitializerInfo normalInitializeInfo(normalInitializeName, 0, embeddingSize, normalInitializerInfo); + + EmbCache::ConstantInitializerInfo constantInitializerInfo(0.233, 1.0); + std::string constantInitializeName = "constant_initializer"; + EmbCache::InitializerInfo constantInitializeInfo(constantInitializeName, embeddingSize, embeddingSize, + constantInitializerInfo); + + std::vector initializeInfos(extEmbeddingSize / embeddingSize); + initializeInfos[0] = normalInitializeInfo; + for (uint64_t i = 1; i < initializeInfos.size(); i++) { + initializeInfos[i] = constantInitializeInfo; + } + embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize); + } + std::vector allTableNames; + std::vector notEmptyVector = { "123" }; + ASSERT_EQ(embCache->GetEmbTableNames(notEmptyVector), H_ARG_NOT_EMPTY); + + ASSERT_EQ(embCache->GetEmbTableNames(allTableNames), H_OK); + bool ret1 = true; + for (auto tableName : allTableNames) { + if (std::find(tableNameVec.begin(), tableNameVec.end(), tableName) == tableNameVec.end()) { + ret1 = false; + } + } + for (auto tableName : tableNameVec) { + if (std::find(allTableNames.begin(), allTableNames.end(), tableName) == allTableNames.end()) { + ret1 = false; + } + } + ASSERT_EQ(ret1, true); + + CTRLog(CTRLogLevel::INFO, "===========GetEmbTableNames end============="); +} + +TEST_F(EmbCacheTest, 
SERIALIZE) +{ + CTRLog(CTRLogLevel::INFO, "===========SERIALIZE start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 5; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 2; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + + std::vector lookupKeys; + + lookupKeys = { 0 }; + std::vector buffer; + ASSERT_EQ(embCache->Serialize("not_a_table", buffer), H_TABLE_NOT_EXIST); + // 表名超过上限 + ASSERT_EQ(embCache->Serialize(tooLongTableName, buffer), H_TABLE_NAME_TOO_LONG); + CTRLog(CTRLogLevel::INFO, "===========SERIALIZE end============="); +} + +TEST_F(EmbCacheTest, DESERIALIZE) +{ + CTRLog(CTRLogLevel::INFO, "===========DESERIALIZE start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 5; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 2; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + + std::vector lookupKeys; + + lookupKeys = { 0 }; + std::vector buffer = { 'A', 'B', '1', '2' }; + ASSERT_EQ(embCache->Deserialize("not_a_table", buffer), H_TABLE_NOT_EXIST); + + ASSERT_EQ(embCache->Deserialize(tooLongTableName, buffer), H_TABLE_NAME_TOO_LONG); + + ASSERT_EQ(embCache->Deserialize(tableName, buffer), H_LOAD_ERROR); + + lookupKeys = { 0, 1, 2, 3, 4 }; + float *newEmb; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_OK); + free(newEmb); + std::vector buffer1; + ASSERT_EQ(embCache->Serialize(tableName, buffer1), H_OK); + buffer1.erase(buffer1.begin() + buffer1.size() / 2, buffer1.end()); + ASSERT_EQ(embCache->Deserialize(tableName, buffer1), H_LOAD_ERROR); + + CTRLog(CTRLogLevel::INFO, "===========DESERIALIZE end============="); +} + +TEST_F(EmbCacheTest, SERIALIZE_DESERIALIZE) +{ + CTRLog(CTRLogLevel::INFO, "===========SERIALIZE_DESERIALIZE start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 5; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 2; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + + std::vector lookupKeys; + lookupKeys = { 0, 1, 2, 3, 4 }; + float *newEmb; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_OK); + free(newEmb); + + std::vector buffer1; + std::vector buffer2; + + ASSERT_EQ(embCache->Serialize(tableName, buffer1), H_OK); + ASSERT_EQ(embCache->Deserialize(tableName, buffer1), H_OK); + ASSERT_EQ(embCache->Serialize(tableName, buffer2), H_OK); + ASSERT_EQ(buffer1.size(), buffer2.size()); + for (uint64_t i = 0; i < buffer1.size(); i++) { + ASSERT_EQ(buffer1[i], buffer2[i]); + } + ASSERT_EQ(buffer1, buffer2); + CTRLog(CTRLogLevel::INFO, "===========SERIALIZE_DESERIALIZE end============="); +} + +TEST_F(EmbCacheTest, ERROR_INITIALIZER) +{ + CTRLog(CTRLogLevel::INFO, "===========ERROR_INITIALIZER start============="); + uint32_t embeddingSize = 13; + /* 对ConstantInitializerInfo的constValue和initK的校验 */ + std::string constantInitializeName = "constant_initializer"; + // 日志打印"constant 
value is less than -1000000000, and will use -1000000000.",并正常初始化InitializerInfo + EmbCache::ConstantInitializerInfo constantInitializerInfo1(-1e9 - 1e8, 1.0); + EmbCache::InitializerInfo constantInitializeInfo = + EmbCache::InitializerInfo(constantInitializeName, embeddingSize, embeddingSize + 1, constantInitializerInfo1); + + // 日志打印"constant value is greater than 1000000000, and will use 1000000000.",并正常初始化InitializerInfo + EmbCache::ConstantInitializerInfo constantInitializerInfo2(1e9 + 1e8, 1.0); + constantInitializeInfo = + EmbCache::InitializerInfo(constantInitializeName, embeddingSize, embeddingSize + 1, constantInitializerInfo2); + + // 日志打印"constant initK is greater than 10000, and will use 10000.",并正常初始化InitializerInfo + EmbCache::ConstantInitializerInfo constantInitializerInfo3(0.233, 10001); + constantInitializeInfo = + EmbCache::InitializerInfo(constantInitializeName, embeddingSize, embeddingSize + 1, constantInitializerInfo3); + + // 日志打印"constant initK is less than -10000, and will use -10000.",并正常初始化InitializerInfo + EmbCache::ConstantInitializerInfo constantInitializerInfo4(0.233, -10001); + constantInitializeInfo = + EmbCache::InitializerInfo(constantInitializeName, embeddingSize, embeddingSize + 1, constantInitializerInfo4); + + /* 对NormalIntializerInfo的mean、stdev和initK的校验 */ + std::string normalInitializeName = "random_normal_initializer"; + // 日志打印"random normal mean param is greater than 1000000000, and will use + // 1000000000.",并正常初始化InitializerInfo + EmbCache::NormalInitializerInfo normalInitializerInfo1(1e9 + 1e8, 0.05, 0, 1.0); + EmbCache::InitializerInfo normalInitializeInfo = + EmbCache::InitializerInfo(normalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo1); + + // 日志打印"random normal mean param is less than -1000000000, and will use + // -1000000000.",并正常初始化InitializerInfo + EmbCache::NormalInitializerInfo normalInitializerInfo2(-1e9 - 1e8, 0.05, 0, 1.0); + normalInitializeInfo = + EmbCache::InitializerInfo(normalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo2); + + // 日志打印"random normal stddev param is greater than 100, and will use 100.",并正常初始化InitializerInfo + EmbCache::NormalInitializerInfo normalInitializerInfo3(0, 101, 0, 1.0); + normalInitializeInfo = + EmbCache::InitializerInfo(normalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo3); + + // 日志打印"random normal stddev param is less than 0, and will use 0.",并正常初始化InitializerInfo + EmbCache::NormalInitializerInfo normalInitializerInfo4(0, -1, 0, 1.0); + normalInitializeInfo = + EmbCache::InitializerInfo(normalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo4); + // 日志打印"random normal initK is greater than 10000, and will use 10000.",并正常初始化InitializerInfo + EmbCache::NormalInitializerInfo normalInitializerInfo5(0, 0.05, 0, 10001); + normalInitializeInfo = + EmbCache::InitializerInfo(normalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo5); + // 日志打印"random normal initK is less than -10000, and will use -10000.",并正常初始化InitializerInfo + EmbCache::NormalInitializerInfo normalInitializerInfo6(0, 0.05, 0, -10001); + normalInitializeInfo = + EmbCache::InitializerInfo(normalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo6); + + /* 对TruncatedNormalInitializer的mean、stdev以及initK的校验 */ + std::string truncatedNormalInitializeName = "truncated_normal_initializer"; + // 日志打印"truncated normal mean param is greater than 1000000000, and will use + // 1000000000.",并正常初始化InitializerInfo + 
+    EmbCache::NormalInitializerInfo normalInitializerInfo7(1e9 + 1e8, 0.05, 0, 1.0);
+    EmbCache::InitializerInfo truncatedNormalInitializeInfo =
+        EmbCache::InitializerInfo(truncatedNormalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo7);
+
+    // logs "truncated normal mean param is less than -1000000000, and will use
+    // -1000000000." and still constructs the InitializerInfo normally
+    EmbCache::NormalInitializerInfo normalInitializerInfo8(-1e9 - 1e8, 0.05, 0, 1.0);
+    truncatedNormalInitializeInfo =
+        EmbCache::InitializerInfo(truncatedNormalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo8);
+
+    // logs "truncated normal stddev param is greater than 100, and will use 100." and still constructs normally
+    EmbCache::NormalInitializerInfo normalInitializerInfo9(0, 101, 0, 1.0);
+    truncatedNormalInitializeInfo =
+        EmbCache::InitializerInfo(truncatedNormalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo9);
+
+    // logs "truncated normal stddev param is less than 0.000000, and will use 0.000000." and still constructs normally
+    EmbCache::NormalInitializerInfo normalInitializerInfo10(0, -1, 0, 1.0);
+    truncatedNormalInitializeInfo =
+        EmbCache::InitializerInfo(truncatedNormalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo10);
+    // logs "truncated normal initK is greater than 10000, and will use 10000." and still constructs normally
+    EmbCache::NormalInitializerInfo normalInitializerInfo11(0, 0.05, 0, 10001);
+    truncatedNormalInitializeInfo =
+        EmbCache::InitializerInfo(truncatedNormalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo11);
+    // logs "truncated normal initK is less than -10000, and will use -10000."
+    EmbCache::NormalInitializerInfo normalInitializerInfo12(0, 0.05, 0, -10001);
+    truncatedNormalInitializeInfo =
+        EmbCache::InitializerInfo(truncatedNormalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo12);
+    CTRLog(CTRLogLevel::INFO, "===========ERROR_INITIALIZER end=============");
+}
+
+
+TEST_F(EmbCacheTest, EmbeddingRemove)
+{
+    CTRLog(CTRLogLevel::INFO, "===========EmbeddingRemove start=============");
+    std::string tableName = "test_table";
+    uint32_t hostVocabSize = 100;
+    uint32_t embeddingSize = 13;
+    uint32_t extEmbeddingSize = 26;
+    uint32_t devVocabSize = 100;
+    embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize);
+    std::vector<uint64_t> lookupKeys;
+    std::vector<uint64_t> removeKeys;
+    float *addr;
+    float *newEmb;
+
+    for (uint32_t i = 0; i < hostVocabSize - 1; i++) {
+        lookupKeys.emplace_back(i);
+        for (uint32_t j = 0; j < hostVocabSize - 1; j++) {
+            removeKeys.emplace_back(i + j);
+        }
+    }
+    addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float));
+    ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK);
+    // table exists
+    ASSERT_EQ(embCache->EmbeddingRemove(tableName, lookupKeys), H_OK);
+    ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK);
+    // single thread
+    ASSERT_EQ(embCache->EmbeddingRemove(tableName, lookupKeys, 1), H_OK);
+
+    free(addr);
+    // remove with empty keys
+    std::vector<uint64_t> emptyRemoveKeys;
+    ASSERT_EQ(embCache->EmbeddingRemove(tableName, emptyRemoveKeys), H_OK);
+
+    // table does not exist
+    ASSERT_EQ(embCache->EmbeddingRemove("not_a_table", lookupKeys), H_TABLE_NOT_EXIST);
+    // table name exceeds the length limit
+    ASSERT_EQ(embCache->EmbeddingRemove(tooLongTableName, lookupKeys), H_TABLE_NAME_TOO_LONG);
+
+    // remove INVALID_KEY
+    uint64_t neg_one = -1;
+    lookupKeys = { neg_one, neg_one, neg_one, neg_one, neg_one };
+    ASSERT_EQ(embCache->EmbeddingRemove(tableName, lookupKeys), H_OK);
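+    // note: neg_one wraps to UINT64_MAX, the reserved INVALID_KEY; removing it
+    // is expected to be a harmless no-op that still reports H_OK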
+ // 判断embLocalTable是否remove掉记录信息 + lookupKeys = { 0, 1, 4 }; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + free(addr); + + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 999.99f; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_OK); + free(newEmb); + + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + bool ret1 = true; + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + if (fabs(addr[i] - 999.99f) > 0.0000001) { + ret1 = false; + } + } + free(addr); + ASSERT_EQ(ret1, true); + + ASSERT_EQ(embCache->EmbeddingRemove(tableName, lookupKeys), H_OK); + + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + bool ret2 = true; + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + if (fabs(addr[i] - 999.99f) <= 0.0000001) { + ret2 = false; + } + } + free(addr); + ASSERT_EQ(ret2, true); + + // 判断offsetMapper是否remove掉记录信息 + lookupKeys = { 6, 0, 8, 1, 3, 4 }; + std::pair, std::vector> swapInKoPair, swapOutKoPair; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, lookupKeys, swapInKoPair, swapOutKoPair), H_OK); + removeKeys = { 0, 1, 4 }; + ASSERT_EQ(embCache->EmbeddingRemove(tableName, removeKeys), H_OK); + + CTRLog(CTRLogLevel::INFO, "===========EmbeddingRemove end============="); +} + +TEST_F(EmbCacheTest, GET_EMB_TABLE_INFO) +{ + CTRLog(CTRLogLevel::INFO, "===========GET_EMB_TABLE_INFO start============="); + std::string tableName = "test_table"; + uint64_t hostVocabSize = 5; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint64_t devVocabSize = 2; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + + std::vector lookupKeys; + lookupKeys = { 0, 1, 2, 3, 4 }; + float *newEmb; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_OK); + free(newEmb); + + std::vector keys; + std::vector> embeddings; + std::vector> optimizerSlots; + + ASSERT_EQ(embCache->GetEmbTableInfos("Invalid_table_name", keys, embeddings, optimizerSlots), H_TABLE_NOT_EXIST); + ASSERT_EQ(embCache->GetEmbTableInfos(tooLongTableName, keys, embeddings, optimizerSlots), H_TABLE_NAME_TOO_LONG); + ASSERT_EQ(embCache->GetEmbTableInfos(tableName, keys, embeddings, optimizerSlots), H_OK); + bool ret = true; + if (keys.size() != 5) { + ret = false; + } + uint32_t optimizerSlotSize = extEmbeddingSize - embeddingSize; + for (auto key : keys) { + auto it = std::find(lookupKeys.begin(), lookupKeys.end(), key); + if (it == lookupKeys.end()) { + ret = false; + break; + } + uint32_t index = it - lookupKeys.begin(); + for (uint32_t i = 0; i < embeddingSize; i++) { + if (fabs(embeddings[index][i] - 0.01f * (i + index * extEmbeddingSize)) > 0.0000001) { + ret = false; + } + } + for (uint32_t i = 0; i < optimizerSlotSize; i++) { + if (fabs(optimizerSlots[index][i] - 0.01f * (i + index * extEmbeddingSize + embeddingSize)) > 0.0000001) { + ret = false; + } + } + } + 
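+    // expected values follow the flat layout written by EmbeddingUpdate: each
+    // key owns extEmbeddingSize consecutive floats, [0, embeddingSize) holding
+    // the embedding and [embeddingSize, extEmbeddingSize) the optimizer slots,
+    // hence 0.01f * (index * extEmbeddingSize + offset)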
ASSERT_EQ(ret, true); + + std::vector keys2 = { 1, 2, 3 }; + std::vector> embeddings2; + std::vector> optimizerSlots2; + ASSERT_EQ(embCache->GetEmbTableInfos(tableName, keys2, embeddings2, optimizerSlots2), H_ARG_NOT_EMPTY); + + std::vector keys3; + std::vector> embeddings3; + std::vector> optimizerSlots3; + embeddings3.emplace_back(std::vector({ 0.1f, 0.2f })); + ASSERT_EQ(embCache->GetEmbTableInfos(tableName, keys3, embeddings3, optimizerSlots3), H_ARG_NOT_EMPTY); + + std::vector keys4; + std::vector> embeddings4; + std::vector> optimizerSlots4; + optimizerSlots4.emplace_back(std::vector({ 0.1f, 0.2f })); + ASSERT_EQ(embCache->GetEmbTableInfos(tableName, keys4, embeddings4, optimizerSlots4), H_ARG_NOT_EMPTY); + embCache->Destroy(); + + hostVocabSize = 5; + embeddingSize = 13; + extEmbeddingSize = 13; + devVocabSize = 2; + + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::vector lookupKeys2; + lookupKeys2 = { 0, 1, 2, 3, 4 }; + float *newEmb2; + newEmb2 = (float *)malloc(lookupKeys2.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys2.size() * extEmbeddingSize; i++) { + newEmb2[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys2, newEmb2), H_OK); + free(newEmb2); + + std::vector keys5; + std::vector> embeddings5; + std::vector> optimizerSlots5; + + ASSERT_EQ(embCache->GetEmbTableInfos(tableName, keys5, embeddings5, optimizerSlots5), H_OK); + bool ret2 = true; + if (keys.size() != 5) { + ret2 = false; + } + for (auto key : keys) { + auto it = std::find(lookupKeys2.begin(), lookupKeys2.end(), key); + if (it == lookupKeys2.end()) { + ret2 = false; + break; + } + uint32_t index = it - lookupKeys2.begin(); + for (uint32_t i = 0; i < embeddingSize; i++) { + if (fabs(embeddings5[index][i] - 0.01f * (i + index * extEmbeddingSize)) > 0.0000001) { + ret2 = false; + } + } + } + if (!optimizerSlots5.empty()) { + ret2 = false; + } + ASSERT_EQ(ret2, true); + + CTRLog(CTRLogLevel::INFO, "===========GET_EMB_TABLE_INFO end============="); +} + +TEST_F(EmbCacheTest, LOAD_EMB_TABLE_INFO) +{ + CTRLog(CTRLogLevel::INFO, "===========LOAD_EMB_TABLE_INFO start============="); + std::string tableName = "test_table"; + uint64_t hostVocabSize = 5; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint64_t devVocabSize = 2; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + + std::vector keys; + std::vector> embeddings; + std::vector> optimizerSlots; + + keys = { 0, 1, 2, 3, 4 }; + for (uint64_t i = 0; i < keys.size(); i++) { + std::vector curEmbedding; + for (uint64_t j = 0; j < embeddingSize; j++) { + curEmbedding.emplace_back(0.01f * (i * extEmbeddingSize + j)); + } + embeddings.emplace_back(curEmbedding); + } + uint32_t optimizerSlotSize = extEmbeddingSize - embeddingSize; + for (uint64_t i = 0; i < keys.size(); i++) { + std::vector curOptimizerSlot; + for (uint64_t j = 0; j < optimizerSlotSize; j++) { + curOptimizerSlot.emplace_back(0.01f * (i * extEmbeddingSize + embeddingSize + j)); + } + optimizerSlots.emplace_back(curOptimizerSlot); + } + ASSERT_EQ(embCache->LoadEmbTableInfos("Invalid_table_name", keys, embeddings, optimizerSlots), H_TABLE_NOT_EXIST); + ASSERT_EQ(embCache->LoadEmbTableInfos(tooLongTableName, keys, embeddings, optimizerSlots), H_TABLE_NAME_TOO_LONG); + ASSERT_EQ(embCache->LoadEmbTableInfos(tableName, keys, embeddings, optimizerSlots), H_OK); + + std::vector keys2; + std::vector> embeddings2; + 
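+    // round-trip check: GetEmbTableInfos below should return exactly the
+    // keys/embeddings/optimizer slots written by LoadEmbTableInfos above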
std::vector> optimizerSlots2; + ASSERT_EQ(embCache->GetEmbTableInfos(tableName, keys2, embeddings2, optimizerSlots2), H_OK); + + bool ret = true; + if (keys2.size() != 5) { + ret = false; + } + for (auto key : keys2) { + auto it = std::find(keys.begin(), keys.end(), key); + if (it == keys.end()) { + ret = false; + break; + } + uint32_t index = it - keys.begin(); + for (uint32_t i = 0; i < embeddingSize; i++) { + if (fabs(embeddings2[index][i] - 0.01f * (i + index * extEmbeddingSize)) > 0.0000001) { + ret = false; + } + } + for (uint32_t i = 0; i < optimizerSlotSize; i++) { + if (fabs(optimizerSlots2[index][i] - 0.01f * (i + index * extEmbeddingSize + embeddingSize)) > 0.0000001) { + ret = false; + } + } + } + ASSERT_EQ(ret, true); + + std::vector keys3; + std::vector> embeddings3; + std::vector> optimizerSlots3; + + keys3 = { 0, 1, 2, 3, 4 }; + for (uint64_t i = 0; i < keys3.size() - 1; i++) { + std::vector curEmbedding; + for (uint64_t j = 0; j < embeddingSize; j++) { + curEmbedding.emplace_back(0.01f * (i * extEmbeddingSize + j)); + } + embeddings3.emplace_back(curEmbedding); + } + for (uint64_t i = 0; i < keys3.size(); i++) { + std::vector curOptimizerSlot; + for (uint64_t j = 0; j < optimizerSlotSize; j++) { + curOptimizerSlot.emplace_back(0.01f * (i * extEmbeddingSize + embeddingSize + j)); + } + optimizerSlots3.emplace_back(curOptimizerSlot); + } + // keys num != embeddings num + ASSERT_EQ(embCache->LoadEmbTableInfos(tableName, keys3, embeddings3, optimizerSlots3), H_LOAD_ERROR); + + std::vector keys4; + std::vector> embeddings4; + std::vector> optimizerSlots4; + + keys4 = { 0, 1, 2, 3, 4 }; + for (uint64_t i = 0; i < keys4.size(); i++) { + std::vector curEmbedding; + for (uint64_t j = 0; j < embeddingSize; j++) { + curEmbedding.emplace_back(0.01f * (i * extEmbeddingSize + j)); + } + embeddings4.emplace_back(curEmbedding); + } + for (uint64_t i = 0; i < keys4.size() - 1; i++) { + std::vector curOptimizerSlot; + for (uint64_t j = 0; j < optimizerSlotSize; j++) { + curOptimizerSlot.emplace_back(0.01f * (i * extEmbeddingSize + embeddingSize + j)); + } + optimizerSlots4.emplace_back(curOptimizerSlot); + } + // keys num != optimizerSlots num + ASSERT_EQ(embCache->LoadEmbTableInfos(tableName, keys4, embeddings4, optimizerSlots4), H_LOAD_ERROR); + + std::vector keys5; + std::vector> embeddings5; + std::vector> optimizerSlots5; + + keys5 = { 0, 1, 2, 3, 4, 5 }; + for (uint64_t i = 0; i < keys5.size(); i++) { + std::vector curEmbedding; + for (uint64_t j = 0; j < embeddingSize; j++) { + curEmbedding.emplace_back(0.01f * (i * extEmbeddingSize + j)); + } + embeddings5.emplace_back(curEmbedding); + } + for (uint64_t i = 0; i < keys5.size(); i++) { + std::vector curOptimizerSlot; + for (uint64_t j = 0; j < optimizerSlotSize; j++) { + curOptimizerSlot.emplace_back(0.01f * (i * extEmbeddingSize + embeddingSize + j)); + } + optimizerSlots5.emplace_back(curOptimizerSlot); + } + // loadKeys num > hostVocabSize + ASSERT_EQ(embCache->LoadEmbTableInfos(tableName, keys5, embeddings5, optimizerSlots5), H_LOAD_ERROR); + + std::vector keys6; + std::vector> embeddings6; + std::vector> optimizerSlots6; + + keys6 = { 0, 1, 2, 3, 4 }; + for (uint64_t i = 0; i < keys6.size(); i++) { + std::vector curEmbedding; + for (uint64_t j = 0; j < embeddingSize - 1; j++) { + curEmbedding.emplace_back(0.01f * (i * extEmbeddingSize + j)); + } + embeddings6.emplace_back(curEmbedding); + } + for (uint64_t i = 0; i < keys6.size(); i++) { + std::vector curOptimizerSlot; + for (uint64_t j = 0; j < optimizerSlotSize; j++) { + 
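+            // slots are built at full size on purpose: only embeddings6 above
+            // are one element short, so the H_LOAD_ERROR below isolates the
+            // embedding-size check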
+            curOptimizerSlot.emplace_back(0.01f * (i * extEmbeddingSize + embeddingSize + j));
+        }
+        optimizerSlots6.emplace_back(curOptimizerSlot);
+    }
+    // input embeddingSize != table embeddingSize
+    ASSERT_EQ(embCache->LoadEmbTableInfos(tableName, keys6, embeddings6, optimizerSlots6), H_LOAD_ERROR);
+
+    std::vector<uint64_t> keys7;
+    std::vector<std::vector<float>> embeddings7;
+    std::vector<std::vector<float>> optimizerSlots7;
+
+    keys7 = { 0, 1, 2, 3, 4 };
+    for (uint64_t i = 0; i < keys7.size(); i++) {
+        std::vector<float> curEmbedding;
+        for (uint64_t j = 0; j < embeddingSize; j++) {
+            curEmbedding.emplace_back(0.01f * (i * extEmbeddingSize + j));
+        }
+        embeddings7.emplace_back(curEmbedding);
+    }
+    for (uint64_t i = 0; i < keys7.size(); i++) {
+        std::vector<float> curOptimizerSlot;
+        for (uint64_t j = 0; j < optimizerSlotSize - 1; j++) {
+            curOptimizerSlot.emplace_back(0.01f * (i * extEmbeddingSize + embeddingSize + j));
+        }
+        optimizerSlots7.emplace_back(curOptimizerSlot);
+    }
+    // input optimizerSlotSize != table optimizerSlotSize
+    ASSERT_EQ(embCache->LoadEmbTableInfos(tableName, keys7, embeddings7, optimizerSlots7), H_LOAD_ERROR);
+    embCache->Destroy();
+
+    hostVocabSize = 5;
+    embeddingSize = 13;
+    extEmbeddingSize = 13;
+    devVocabSize = 2;
+
+    embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize);
+
+    std::vector<uint64_t> keys8;
+    std::vector<std::vector<float>> embeddings8;
+    std::vector<std::vector<float>> optimizerSlots8;
+
+    keys8 = { 0, 1, 2, 3, 4 };
+    for (uint64_t i = 0; i < keys8.size(); i++) {
+        std::vector<float> curEmbedding;
+        for (uint64_t j = 0; j < embeddingSize; j++) {
+            curEmbedding.emplace_back(0.01f * (i * extEmbeddingSize + j));
+        }
+        embeddings8.emplace_back(curEmbedding);
+    }
+
+    ASSERT_EQ(embCache->LoadEmbTableInfos(tableName, keys8, embeddings8, optimizerSlots8), H_OK);
+
+    std::vector<uint64_t> keys9;
+    std::vector<std::vector<float>> embeddings9;
+    std::vector<std::vector<float>> optimizerSlots9;
+    ASSERT_EQ(embCache->GetEmbTableInfos(tableName, keys9, embeddings9, optimizerSlots9), H_OK);
+
+    double eps = 0.0000001;
+    bool ret2 = true;
+    if (keys9.size() != 5) {
+        ret2 = false;
+    }
+    for (auto key : keys9) {
+        // compare against the source keys (keys8); searching keys9 for its own
+        // elements would always succeed and check nothing
+        auto it = std::find(keys8.begin(), keys8.end(), key);
+        if (it == keys8.end()) {
+            ret2 = false;
+            break;
+        }
+        uint32_t index = it - keys8.begin();
+        for (uint32_t i = 0; i < embeddingSize; i++) {
+            if (fabs(embeddings9[index][i] - 0.01f * (i + index * extEmbeddingSize)) > eps) {
+                ret2 = false;
+            }
+        }
+    }
+    if (!optimizerSlots9.empty()) {
+        ret2 = false;
+    }
+    ASSERT_EQ(ret2, true);
+
+    CTRLog(CTRLogLevel::INFO, "===========LOAD_EMB_TABLE_INFO end=============");
+}
diff --git a/src/AccCTR/tests/ut/src/emb_cache_test.h b/src/AccCTR/tests/ut/src/emb_cache_test.h
new file mode 100644
index 0000000000000000000000000000000000000000..5c87237ba1cf63df43164cdd7ccf6fae0379909f
--- /dev/null
+++ b/src/AccCTR/tests/ut/src/emb_cache_test.h
@@ -0,0 +1,62 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+ ==============================================================================*/
+
+#ifndef CTR_EMB_CACHE_TEST_H
+#define CTR_EMB_CACHE_TEST_H
+
+#include <cstdint>
+#include <string>
+#include <utility>
+
+#include "gtest/gtest.h"
+#include "gmock/gmock.h"
+
+#include "factory.h"
+#include "embedding_cache.h"
+
+
+class EmbCacheTest : public testing::Test {
+protected:
+    EmbCacheTest() {}
+    ~EmbCacheTest() {}
+    static void SetUpTestCase();
+    static void TearDownTestCase();
+
+
+    void SetUp() override;
+
+    void TearDown() override;
+
+    static ock::ctr::EmbCacheManagerPtr SimpleCreateTable(std::string tableName, uint32_t hostVocabSize,
+        uint32_t embeddingSize, uint32_t extEmbeddingSize, uint32_t devVocabSize,
+        std::pair<float, float> normalPara = { 0, 0.05 }, float constPara = 0.233);
+
+    static ock::ctr::EmbCacheManagerPtr ConstZeroCreateTable(std::string tableName, uint32_t hostVocabSize,
+        uint32_t embeddingSize, uint32_t extEmbeddingSize, uint32_t devVocabSize, uint64_t prefillBufferSize = 50000,
+        uint8_t prefillThreadNum = 1);
+
+    std::string tooLongTableName =
+        "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001"
+        "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001"
+        "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001"
+        "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001"
+        "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001"
+        "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001"
+        "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001"
+        "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001"
+        "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001"
+        "00000000010000000001000000000100012";
+};
+
+#endif // CTR_EMB_CACHE_TEST_H
diff --git a/src/AccCTR/tests/ut/src/unique_test.cpp b/src/AccCTR/tests/ut/src/unique_test.cpp
index ef6846f8077fe843271e6f17bacea37c6f2f100c..df5950e1e9dbe611b726c415acbb4f6ccf68facf 100644
--- a/src/AccCTR/tests/ut/src/unique_test.cpp
+++ b/src/AccCTR/tests/ut/src/unique_test.cpp
@@ -11,12 +11,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
==============================================================================*/ +#include "unique_test.h" #include #include -#include "unique_test.h" - -FactoryPtr factory; +#include "common.h" void UniqueTest::SetUpTestCase() { @@ -96,6 +95,13 @@ TEST_F(UniqueTest, Conf) ASSERT_EQ(unique->DoEnhancedUnique(uniqueIn, uniqueOut), 3); // idCntFill空指针 uniqueOut.idCntFill = idCntFill; ASSERT_EQ(unique->DoEnhancedUnique(uniqueIn, uniqueOut), 7); // padding长度过小 + + unique->UnInitialize(); + delete[] idCnt; + delete[] idCntFill; + delete[] uniqueIdCntInBucket; + delete[] uniqueIdInBucket; + std::cout << "===========Conf end=============" << std::endl; } @@ -116,6 +122,9 @@ TEST_F(UniqueTest, usePaddingNoShardingErr) conf.outputType = OutputType::ENHANCED; ASSERT_EQ(unique->Initialize(conf), 9); + + unique->UnInitialize(); + std::cout << "===========usePaddingNoShardingErr end=============" << std::endl; } @@ -133,6 +142,8 @@ TEST_F(UniqueTest, useNegativeDesiredSize) ASSERT_EQ(unique->Initialize(conf), 1); + unique->UnInitialize(); + std::cout << "===========useNegativeDesiredSize end=============" << std::endl; } @@ -144,7 +155,10 @@ TEST_F(UniqueTest, DoUniqueNormal) std::string input_path(path); std::cout << "input_path:" + input_path + "/data30.txt" << std::endl; std::ifstream input(input_path + "/data30.txt"); - + if (!input.good()) { + std::cout << "Failed to open file:" + input_path + "/data30.txt" << std::endl; + return; + } std::vector numbers; std::string line; while (std::getline(input, line, ',')) { @@ -156,6 +170,8 @@ TEST_F(UniqueTest, DoUniqueNormal) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.trace = true; conf.desiredSize = numbers.size(); @@ -203,6 +219,9 @@ TEST_F(UniqueTest, DoUniqueNormal) ASSERT_EQ(uniqueOut.uniqueIdCnt, (int)idsSet.size()); unique->UnInitialize(); + if (path) { + free(path); + } std::cout << "===========DoUniqueNormal end=============" << std::endl; } @@ -213,6 +232,8 @@ TEST_F(UniqueTest, UseErrOutputTypeEnhanced) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.desiredSize = 6; conf.dataType = DataType::INT64; @@ -253,6 +274,8 @@ TEST_F(UniqueTest, UseErrOutputTypeNormal) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.desiredSize = 6; conf.dataType = DataType::INT64; @@ -292,6 +315,8 @@ TEST_F(UniqueTest, DoEnhancedUnique) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.desiredSize = 6; conf.dataType = DataType::INT64; @@ -340,6 +365,8 @@ TEST_F(UniqueTest, DoEnhancedUniqueErr) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.desiredSize = 6; conf.dataType = DataType::INT64; @@ -392,6 +419,9 @@ TEST_F(UniqueTest, DoEnhancedUniqueErr) ASSERT_EQ(uniqueOut.uniqueIdCnt, (int)idsSet.size()); unique->UnInitialize(); + delete[] uniqueIdInBucket; + delete[] idCnt; + std::cout << "===========DoEnhancedUniqueErr end=============" << std::endl; } @@ -402,6 +432,8 @@ TEST_F(UniqueTest, DoEnhancedUnique_UniqueIdSize) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.desiredSize = 6; conf.dataType = DataType::INT64; @@ -449,6 +481,8 @@ TEST_F(UniqueTest, idCntIsNull) UniquePtr 
unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.desiredSize = 6; conf.dataType = DataType::INT64; @@ -488,6 +522,8 @@ TEST_F(UniqueTest, idCntIsNullSharding) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.desiredSize = 6; conf.dataType = DataType::INT64; @@ -526,6 +562,9 @@ TEST_F(UniqueTest, idCntIsNullSharding) ASSERT_EQ(ret, 3); unique->UnInitialize(); + delete[] uniqueIdCntInBucket; + delete[] uniqueIdInBucket; + std::cout << "===========idCntIsNullSharding end=============" << std::endl; } @@ -537,6 +576,8 @@ TEST_F(UniqueTest, DoUniqueShard) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.useSharding = true; conf.useIdCount = true; @@ -600,6 +641,7 @@ TEST_F(UniqueTest, DoUniqueShard) ASSERT_THAT(uniqueIdCntInBucket, testing::ElementsAreArray(expectedUniqueIdCnt)); ASSERT_THAT(idCnt, testing::ElementsAreArray(expectedIdCnt)); unique->UnInitialize(); + delete[] uniqueIdInBucket; std::cout << "===========DoUniqueShard end=============" << std::endl; } @@ -612,6 +654,8 @@ TEST_F(UniqueTest, DoUniqueOnlyShard) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.useSharding = true; conf.desiredSize = 6; @@ -663,6 +707,7 @@ TEST_F(UniqueTest, DoUniqueOnlyShard) ASSERT_THAT(inputId, testing::ElementsAreArray(restoreIds)); ASSERT_THAT(uniqueIdCntInBucket, testing::ElementsAreArray(expectedUniqueIdCnt)); unique->UnInitialize(); + delete[] uniqueIdInBucket; std::cout << "===========DoUniqueOnlyShard end=============" << std::endl; } @@ -675,6 +720,8 @@ TEST_F(UniqueTest, DoUniquePadding) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.usePadding = true; conf.useSharding = true; @@ -745,6 +792,8 @@ TEST_F(UniqueTest, DoUniquePadding) ASSERT_THAT(idCntFill, testing::ElementsAreArray(expectedIdCnt)); ASSERT_EQ(uniqueOut.uniqueIdCnt, conf.paddingSize * conf.shardingNum); unique->UnInitialize(); + delete[] idCnt; + delete[] uniqueIdInBucket; std::cout << "===========DoUniquePadding end=============" << std::endl; } @@ -755,6 +804,8 @@ TEST_F(UniqueTest, DoUniqueNoThreadPool) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.desiredSize = 20; // 配置空间大于实际输入数组长度,验证正常运行 conf.dataType = DataType::INT64; @@ -817,6 +868,8 @@ TEST_F(UniqueTest, DoUniqueShardNumberOversize) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.useSharding = true; conf.desiredSize = 6; @@ -885,6 +938,7 @@ TEST_F(UniqueTest, DoUniqueShardNumberOversize) ASSERT_THAT(uniqueIdCntInBucket, testing::ElementsAreArray(expectedUniqueIdCnt)); ASSERT_THAT(idCnt, testing::ElementsAreArray(expectedIdCnt)); unique->UnInitialize(); + delete[] uniqueIdInBucket; std::cout << "===========DoUniqueShardNumberOversize end=============" << std::endl; } @@ -895,6 +949,7 @@ TEST_F(UniqueTest, DoUniqueSpecial) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); int count = 1000000; UniqueConf conf; @@ -952,6 +1007,12 @@ TEST_F(UniqueTest, DoUniqueSpecial) } unique->UnInitialize(); + delete[] uniqueData; + delete[] index; + delete[] idCnt; + 
delete[] idCntFill; + delete[] uniqueIdCntInBucket; + delete[] uniqueIdInBucket; std::cout << "===========DoUniqueSpecial end=============" << std::endl; } @@ -963,6 +1024,8 @@ TEST_F(UniqueTest, IdLarge) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.desiredSize = 6; conf.dataType = DataType::INT64; @@ -989,6 +1052,10 @@ TEST_F(UniqueTest, IdLarge) uniqueOut.idCnt = idCnt; ASSERT_EQ(unique->DoEnhancedUnique(uniqueIn, uniqueOut), 6); // ID太大 + + unique->UnInitialize(); + delete[] idCnt; + std::cout << "===========IdLarge end=============" << std::endl; } @@ -999,6 +1066,8 @@ TEST_F(UniqueTest, DoUniqueNormalInt32) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.useSharding = true; conf.desiredSize = 6; @@ -1062,6 +1131,8 @@ TEST_F(UniqueTest, DoUniqueNormalInt32) ASSERT_THAT(idCnt, testing::ElementsAreArray(expectedIdCnt)); unique->UnInitialize(); + delete[] uniqueIdInBucket; + std::cout << "===========DoUniqueNormalInt32 end=============" << std::endl; } @@ -1122,6 +1193,8 @@ TEST_F(UniqueTest, DoUniqueShardMultipleTimes) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.useSharding = true; conf.desiredSize = 6; @@ -1162,14 +1235,14 @@ TEST_F(UniqueTest, DoUniqueShardMultipleTimes) unordered_set uniqueIdSet; map expectedIdCntMap; - for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) { - restoreIds[i] = uniqueId[index[i]]; - expectedIdCntMap[inputId[i]]++; - if (uniqueIdSet.find(inputId[i]) != uniqueIdSet.end()) { + for (size_t j = 0; j < uniqueIn.inputIdCnt; j++) { + restoreIds[j] = uniqueId[index[j]]; + expectedIdCntMap[inputId[j]]++; + if (uniqueIdSet.find(inputId[j]) != uniqueIdSet.end()) { continue; } else { - uniqueIdSet.insert(inputId[i]); - expectedUniqueIdCnt[inputId[i] % conf.shardingNum]++; + uniqueIdSet.insert(inputId[j]); + expectedUniqueIdCnt[inputId[j] % conf.shardingNum]++; } } @@ -1177,13 +1250,14 @@ TEST_F(UniqueTest, DoUniqueShardMultipleTimes) int uniqueSum = 0; - for (int i = 0; i < conf.shardingNum; i++) { - uniqueSum += uniqueIdCntInBucket[i]; + for (int j = 0; j < conf.shardingNum; j++) { + uniqueSum += uniqueIdCntInBucket[j]; } vector expectedIdCnt(uniqueSum); - for (int i = 0; i < uniqueSum; i++) { - expectedIdCnt[i] = expectedIdCntMap[uniqueId[i]]; + + for (int j = 0; j < uniqueSum; j++) { + expectedIdCnt[j] = expectedIdCntMap[uniqueId[j]]; } expectedIdCnt.resize(uniqueIn.inputIdCnt); @@ -1192,6 +1266,7 @@ TEST_F(UniqueTest, DoUniqueShardMultipleTimes) ASSERT_THAT(idCnt, testing::ElementsAreArray(expectedIdCnt)); } unique->UnInitialize(); + delete[] uniqueIdInBucket; std::cout << "===========DoUniqueShardMultipleTimes end=============" << std::endl; } @@ -1276,6 +1351,9 @@ TEST_F(UniqueTest, DoUniquePaddingMultipleTimes) } unique->UnInitialize(); + delete[] idCnt; + delete[] uniqueIdInBucket; + std::cout << "===========DoUniquePaddingMultipleTimes end=============" << std::endl; } @@ -1285,6 +1363,8 @@ TEST_F(UniqueTest, IdCntSmall) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.desiredSize = 6; conf.dataType = DataType::INT64; @@ -1310,6 +1390,10 @@ TEST_F(UniqueTest, IdCntSmall) uniqueOut.idCnt = idCnt; ASSERT_EQ(unique->DoEnhancedUnique(uniqueIn, uniqueOut), 4); // idcnt过小 + + unique->UnInitialize(); + delete[] idCnt; + std::cout << 
"===========IdCntSmall end=============" << std::endl; } @@ -1320,7 +1404,10 @@ TEST_F(UniqueTest, DoUniqueLotsDataFunction) std::string input_path(path); std::cout << "input_path:" + input_path + "/data40.txt" << std::endl; std::ifstream input(input_path + "/data40.txt"); - + if (!input.good()) { + std::cout << "Failed to open file:" + input_path + "/data40.txt" << std::endl; + return; + } std::vector numbers; std::string line; while (std::getline(input, line, ',')) { @@ -1408,6 +1495,7 @@ TEST_F(UniqueTest, DoUniqueLotsDataFunction) ASSERT_THAT(idCnt, testing::ElementsAreArray(expectedIdCnt)); unique->UnInitialize(); + delete[] uniqueIdInBucket; if (path) { free(path); } @@ -1422,7 +1510,10 @@ TEST_F(UniqueTest, DoUniqueLotsDataPaddingFunction) std::string input_path(path); std::cout << "input_path:" + input_path + "/data30.txt" << std::endl; std::ifstream input(input_path + "/data30.txt"); - + if (!input.good()) { + std::cout << "Failed to open file:" + input_path + "/data30.txt" << std::endl; + return; + } std::vector numbers; std::string line; while (std::getline(input, line, ',')) { @@ -1513,6 +1604,8 @@ TEST_F(UniqueTest, DoUniqueLotsDataPaddingFunction) unique->UnInitialize(); ASSERT_EQ(unique->DoEnhancedUnique(uniqueIn, uniqueOut), 11); + delete[] idCnt; + delete[] uniqueIdInBucket; if (path) { free(path); } diff --git a/src/AccCTR/tests/ut/src/unique_test.h b/src/AccCTR/tests/ut/src/unique_test.h index 0243f262c9b19c52c4e71e1688d26f8fe3342d0e..c3bc64f3e37844d70d2fd33f86c684e3732f4130 100644 --- a/src/AccCTR/tests/ut/src/unique_test.h +++ b/src/AccCTR/tests/ut/src/unique_test.h @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include "factory.h" #include "gtest/gtest.h" #include "gmock/gmock.h" #include "unique.h" @@ -28,21 +27,6 @@ using namespace std; using namespace ock::ctr; -class SimpleThreadPool { -public: - static void SyncRun(const std::vector> &tasks) - { - std::vector> futs; - for (auto &task : tasks) { - futs.push_back(std::async(task)); - } - for (auto &fut : futs) { - fut.wait(); - } - } -}; - - class UniqueTest : public testing::Test { protected: UniqueTest() {}; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 84505d150a47d11d473649a0aa0190288ba59638..a5cd76da7965308c6afe9c307153a1653507639e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -56,7 +56,7 @@ else () message("==EASY_PROFILER_FOUND===") ADD_DEFINITIONS(-DBUILD_WITH_EASY_PROFILER) endif () -set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -ffunction-sections -O0 -Wall -g2 -ggdb") +set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -ffunction-sections -O0 -Wall -g2 -ggdb -fsanitize=address -fsanitize-recover=address,all -fno-omit-frame-pointer -fno-stack-protector") set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -ffunction-sections -O3 -Wfatal-errors -DNDEBUG -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -s") set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack") diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index dd1052f25659b5880ae561dac177e8b529edc0c8..64a076b9a9f22ccafbc01d365d37d00a2df7a6eb 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -27,6 +27,11 @@ if(NOT SECUREC_PATH) endif() message("SECUREC_PATH: " ${SECUREC_PATH}) +if(NOT ACCCTR_PATH) + set(ACCCTR_PATH ${PROJECT_SOURCE_DIR}/AccCTR) +endif() +message("ACCCTR_PATH: " ${ACCCTR_PATH}) + include_directories(${ABSEIL_PATH}/include) link_directories(${ABSEIL_PATH}/lib) @@ -38,7 +43,7 @@ endif() link_libraries(stdc++fs) 
-file(GLOB_RECURSE MXREC_SRC ./*.cpp) +file(GLOB_RECURSE MXREC_SRC ./*.cpp ./*.h) add_library(ASC SHARED ${MXREC_SRC}) target_include_directories(ASC @@ -55,10 +60,11 @@ target_link_directories(ASC ${HDF5_PATH}/lib ${SECUREC_PATH}/lib ${ASCEND_DRIVER_PATH}/lib64/driver + ${ACCCTR_PATH}/output/ock_ctr_common/lib ) target_link_libraries(ASC PUBLIC ascendcl msprofiler ge_executor gert runtime ge_common register graph ascend_protobuf - profapi opt_feature error_manager exe_graph acl_tdt_channel acl_tdt_queue securec drvdsmi_host) + profapi opt_feature error_manager exe_graph acl_tdt_channel acl_tdt_queue securec drvdsmi_host _ock_ctr_common) target_link_libraries(ASC PUBLIC -l:_tf_adapter.so OpenMP::OpenMP_CXX ${MPI_CXX_LIBRARIES} diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 673c7ce386f5944ab34752940ec01c18488cbfff..469e209e9d4af7f1e7716215679a23bdbe8edaf3 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -13,21 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include -#include -#include +#include "checkpoint.h" + #include #include +#include +#include +#include + +#include #include "ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h" -#include "ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.h" #include "ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.h" -#include "utils/time_cost.h" -#include "utils/common.h" +#include "ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.h" #include "file_system/file_system_handler.h" - -#include "checkpoint.h" +#include "utils/common.h" +#include "utils/time_cost.h" using namespace std; using namespace MxRec; @@ -89,11 +90,18 @@ void Checkpoint::SetDataHandler(CkptData& ckptData) void Checkpoint::SetDataHandler(const vector& featureTypes) { - map> setCkptMap{ - {CkptFeatureType::FEAT_ADMIT_N_EVICT, [this] { dataHandlers.push_back(make_unique()); }}, - {CkptFeatureType::DDR_KEY_FREQ_MAP, [this] { dataHandlers.push_back(make_unique()); }}, - {CkptFeatureType::KEY_COUNT_MAP, [this] { dataHandlers.push_back(make_unique()); }} + auto featAdmitNEvictHandler = [this] { + dataHandlers.push_back(make_unique()); }; + auto ddrKeyFreqMapHandler = [this] { + dataHandlers.push_back(make_unique()); + }; + auto keyCountMapHandler = [this] { + dataHandlers.push_back(make_unique()); + }; + map> setCkptMap{{CkptFeatureType::FEAT_ADMIT_N_EVICT, featAdmitNEvictHandler}, + {CkptFeatureType::DDR_KEY_FREQ_MAP, ddrKeyFreqMapHandler}, + {CkptFeatureType::KEY_COUNT_MAP, keyCountMapHandler}}; for (const auto& featureType : featureTypes) { setCkptMap.at(featureType)(); @@ -104,8 +112,8 @@ void Checkpoint::SaveProcess(CkptData& ckptData) { for (const auto& dataHandler : dataHandlers) { dataHandler->SetProcessData(ckptData); - vector embNames { dataHandler->GetEmbNames() }; - vector saveDataTypes { dataHandler->GetDataTypes() }; + vector embNames{dataHandler->GetEmbNames()}; + vector saveDataTypes{dataHandler->GetDataTypes()}; MakeUpperLayerSaveDir(); MakeDataLayerSaveDir(embNames, saveDataTypes, dataHandler); SaveDataset(embNames, saveDataTypes, dataHandler); @@ -118,17 +126,16 @@ void Checkpoint::MakeUpperLayerSaveDir() MakeSaveDir(innerDirPath); } -void Checkpoint::MakeDataLayerSaveDir(const vector& embNames, - const vector& saveDataTypes, +void Checkpoint::MakeDataLayerSaveDir(const vector& embNames, const vector& 
saveDataTypes, const unique_ptr& dataHandler) { for (const auto& embName : embNames) { - auto dataDir { innerDirPath + dirSeparator + embName }; + auto dataDir{innerDirPath + dirSeparator + embName}; MakeSaveDir(dataDir); for (const auto& saveDataType : saveDataTypes) { - auto dataDirName { dataHandler->GetDataDirName(saveDataType) }; - auto datasetPath { dataDir + dirSeparator + dataDirName }; + auto dataDirName{dataHandler->GetDataDirName(saveDataType)}; + auto datasetPath{dataDir + dirSeparator + dataDirName}; MakeSaveDir(datasetPath); } } @@ -146,7 +153,7 @@ void Checkpoint::MakeSaveDir(const string& dirName) const Checkpoint::EmbSizeInfo Checkpoint::GetEmbeddingSize(const string& embName) { EmbSizeInfo embSizeInfo; - for (const auto &embInfo: mgmtEmbInfo) { + for (const auto& embInfo : mgmtEmbInfo) { if (embInfo.name == embName) { embSizeInfo.embSize = embInfo.embeddingSize; embSizeInfo.extEmbSize = embInfo.extEmbeddingSize; @@ -158,29 +165,28 @@ Checkpoint::EmbSizeInfo Checkpoint::GetEmbeddingSize(const string& embName) bool Checkpoint::CheckEmbNames(const string& embName) { - for (const auto &embInfo: mgmtEmbInfo) { - if (embInfo.name == embName && embInfo.isSave) { + for (const auto& embInfo : mgmtEmbInfo) { + if (embInfo.name == embName && embInfo.isSave) { return true; } } return false; } -void Checkpoint::SaveDataset(const vector& embNames, - const vector& saveDataTypes, +void Checkpoint::SaveDataset(const vector& embNames, const vector& saveDataTypes, const unique_ptr& dataHandler) { - for (const auto& embName: embNames) { + for (const auto& embName : embNames) { if (!CheckEmbNames(embName)) { continue; } auto dataDir{innerDirPath + dirSeparator + embName}; - for (const auto& saveDataType: saveDataTypes) { - auto datasetPath { dataDir + dirSeparator + dataHandler->GetDataDirName(saveDataType) }; - auto datasetDir { datasetPath + dirSeparator + datasetName + to_string(rankId) + dataFileType }; + for (const auto& saveDataType : saveDataTypes) { + auto datasetPath{dataDir + dirSeparator + dataHandler->GetDataDirName(saveDataType)}; + auto datasetDir{datasetPath + dirSeparator + datasetName + to_string(rankId) + dataFileType}; LOG_DEBUG("====Start getting data from handler to: {}", datasetDir); - auto transData { dataHandler->GetDataset(saveDataType, embName) }; + auto transData{dataHandler->GetDataset(saveDataType, embName)}; LOG_DEBUG("====Start saving data to: {}", datasetDir); WriteStream(transData, datasetDir, transData.datasetSize, saveDataType); @@ -196,32 +202,37 @@ void Checkpoint::WriteStream(CkptTransData& transData, const string& dataDir, si } ssize_t writeBytesNum; - if (floatTransSet.find(dataType) != floatTransSet.end()) { - writeBytesNum = fileSystemPtr->Write(dataDir, transData.floatArr, dataSize); - } else if (int32TransSet.find(dataType) != int32TransSet.end()) { - writeBytesNum = fileSystemPtr->Write(dataDir, - reinterpret_cast(transData.int32Arr.data()), dataSize); + if (int32TransSet.find(dataType) != int32TransSet.end()) { + writeBytesNum = + fileSystemPtr->Write(dataDir, reinterpret_cast(transData.int32Arr.data()), dataSize); } else if (int64TransSet.find(dataType) != int64TransSet.end()) { - writeBytesNum = fileSystemPtr->Write(dataDir, - reinterpret_cast(transData.int64Arr.data()), dataSize); + writeBytesNum = + fileSystemPtr->Write(dataDir, reinterpret_cast(transData.int64Arr.data()), dataSize); } else if (dataType == CkptDataType::ATTRIBUTE) { - writeBytesNum = fileSystemPtr->Write(dataDir, - reinterpret_cast(transData.attribute.data()), dataSize); + 
writeBytesNum = + fileSystemPtr->Write(dataDir, reinterpret_cast(transData.attribute.data()), dataSize); + } else { + throw runtime_error("unknown CkptDataType"); } if (writeBytesNum == -1) { - LOG_ERROR("error happened when writing data to file."); - throw runtime_error("error happened when writing data to file."); + throw runtime_error(StringFormat("Error: Save data failed. data type: %s. " + "An error occurred while writing file: %s.", + CkptDataTypeName(dataType).c_str(), dataDir.c_str())); + } + if (writeBytesNum != dataSize) { + throw runtime_error(StringFormat("Error: Save data failed. data type: %s. " + "Expected to write %d bytes, but actually write %d bytes to file %s.", + CkptDataTypeName(dataType).c_str(), dataSize, writeBytesNum, dataDir.c_str())); } } - void Checkpoint::LoadProcess(CkptData& ckptData) { for (const auto& dataHandler : dataHandlers) { - vector embNames {}; - vector dirNames { dataHandler->GetDirNames() }; - vector saveDataTypes { dataHandler->GetDataTypes() }; + vector embNames{}; + vector dirNames{dataHandler->GetDirNames()}; + vector saveDataTypes{dataHandler->GetDataTypes()}; innerDirPath = processPath; if (find(dirNames.begin(), dirNames.end(), ssdSymbol) != dirNames.end()) { embNames = GetTableLayerLoadDir(); @@ -233,7 +244,6 @@ void Checkpoint::LoadProcess(CkptData& ckptData) } } - vector Checkpoint::GetEmbedTableNames() { vector loadTableNames; @@ -257,23 +267,20 @@ vector Checkpoint::GetTableLayerLoadDir() return loadTableDir; } -void Checkpoint::LoadDataset(const vector& embNames, - const vector& saveDataTypes, - const unique_ptr& dataHandler, - CkptData& ckptData) +void Checkpoint::LoadDataset(const vector& embNames, const vector& saveDataTypes, + const unique_ptr& dataHandler, CkptData& ckptData) { for (const auto& embName : embNames) { - auto dataDir { innerDirPath + dirSeparator + embName }; + auto dataDir{innerDirPath + dirSeparator + embName}; for (const auto& saveDataType : saveDataTypes) { - auto datasetPath { dataDir + dirSeparator + dataHandler->GetDataDirName(saveDataType) }; + auto datasetPath{dataDir + dirSeparator + dataHandler->GetDataDirName(saveDataType)}; - auto datasetDir { datasetPath + dirSeparator + "slice" + dataFileType }; - auto attributeDir { datasetPath + dirSeparator + "slice" + attribFileType }; + auto datasetDir{datasetPath + dirSeparator + "slice" + dataFileType}; + auto attributeDir{datasetPath + dirSeparator + "slice" + attribFileType}; CkptTransData transData; - LOG_DEBUG("====Start reading data from: {}", attributeDir); - auto dataElmtBytes { dataHandler->GetDataElmtBytes(CkptDataType::ATTRIBUTE) }; + auto dataElmtBytes{dataHandler->GetDataElmtBytes(CkptDataType::ATTRIBUTE)}; ReadStream(transData, attributeDir, CkptDataType::ATTRIBUTE, dataElmtBytes); dataElmtBytes = dataHandler->GetDataElmtBytes(saveDataType); @@ -286,7 +293,7 @@ void Checkpoint::LoadDataset(const vector& embNames, } LOG_DEBUG("====Start loading data from: {} to data handler.", attributeDir); - if ((saveDataType == CkptDataType::EMB_INFO)) { + if ((saveDataType == CkptDataType::EMB_INFO)) { dataHandler->SetDatasetForLoadEmb(saveDataType, embName, transData, ckptData); } else { dataHandler->SetDataset(saveDataType, embName, transData); @@ -295,14 +302,12 @@ void Checkpoint::LoadDataset(const vector& embNames, } } -void Checkpoint::ReadStream(CkptTransData& transData, - const string& dataDir, - CkptDataType dataType, +void Checkpoint::ReadStream(CkptTransData& transData, const string& dataDir, CkptDataType dataType, uint32_t dataElmtBytes) { if 
(dataElmtBytes == 0) { LOG_WARN("dataElmtBytes is 0, don't handle [/ %] operation"); - return ; + return; } if (fileSystemPtr == nullptr) { @@ -311,7 +316,7 @@ void Checkpoint::ReadStream(CkptTransData& transData, } size_t datasetSize = fileSystemPtr->GetFileSize(dataDir); - auto resizeSize { datasetSize / dataElmtBytes }; + auto resizeSize{datasetSize / dataElmtBytes}; SetTransDataSize(transData, resizeSize, dataType); if (datasetSize % dataElmtBytes > 0) { @@ -323,27 +328,31 @@ void Checkpoint::ReadStream(CkptTransData& transData, readBytesNum = fileSystemPtr->Read(dataDir, reinterpret_cast(transData.int32Arr.data()), datasetSize); } else if (int64TransSet.find(dataType) != int64TransSet.end()) { readBytesNum = fileSystemPtr->Read(dataDir, reinterpret_cast(transData.int64Arr.data()), datasetSize); - } else if (floatTransSet.find(dataType) != floatTransSet.end()) { - readBytesNum = fileSystemPtr->Read(dataDir, reinterpret_cast(transData.floatArr.data()), datasetSize); } else if (dataType == CkptDataType::ATTRIBUTE) { - readBytesNum = fileSystemPtr->Read(dataDir, reinterpret_cast(transData.attribute.data()), datasetSize); + readBytesNum = fileSystemPtr->Read(dataDir, reinterpret_cast(transData.attribute.data()), datasetSize); + } else { + throw runtime_error("unknown CkptDataType"); } if (readBytesNum == -1) { - LOG_ERROR("error happened when reading data from file."); - throw runtime_error("error happened when reading data from file."); + throw runtime_error(StringFormat("Error: Load data failed. data type: %s. " + "An error occurred while reading file: %s.", + CkptDataTypeName(dataType).c_str(), dataDir.c_str())); + } + if (readBytesNum != datasetSize) { + throw runtime_error(StringFormat("Error: Load data failed. data type: %s. " + "Expected to read %d bytes, but actually read %d bytes to file %s.", + CkptDataTypeName(dataType).c_str(), datasetSize, readBytesNum, + dataDir.c_str())); } } -void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, - const string& dataDir, - uint32_t dataElmtBytes, - CkptData& ckptData, - string embName) const +void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, const string& dataDir, uint32_t dataElmtBytes, + CkptData& ckptData, string embName) const { if (dataElmtBytes == 0) { LOG_ERROR("dataElmtBytes is 0, don't handle [/ %] operation"); - return ; + return; } if (fileSystemPtr == nullptr) { @@ -373,9 +382,9 @@ void Checkpoint::SetTransDataSize(CkptTransData& transData, size_t datasetSize, transData.int32Arr.resize(datasetSize); } else if (int64TransSet.find(dataType) != int64TransSet.end()) { transData.int64Arr.resize(datasetSize); - } else if (floatTransSet.find(dataType) != floatTransSet.end()) { - transData.floatArr.resize(datasetSize); } else if (dataType == CkptDataType::ATTRIBUTE) { transData.attribute.resize(datasetSize); + } else { + throw runtime_error("unknown CkptDataType"); } } diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index 362881b255dd02202cf78cc6a4f902fa9cffcb4d..625660ff9d4f9a25535708f820f861a1b5952520 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -63,9 +63,6 @@ namespace MxRec { CkptDataType::KEY_COUNT_MAP, CkptDataType::EVICT_POS }; - const set floatTransSet{ - CkptDataType::EMB_DATA - }; vector> dataHandlers; string processPath; diff --git a/src/core/ckpt_data_handler/ckpt_data_handler.cpp b/src/core/ckpt_data_handler/ckpt_data_handler.cpp index 18f1a0903724a3321275c09519e18ec82cc7ba7f..04feb4b3d281cc7ee9bfca011ac9c3f8f88f9856 100644 
--- a/src/core/ckpt_data_handler/ckpt_data_handler.cpp +++ b/src/core/ckpt_data_handler/ckpt_data_handler.cpp @@ -33,7 +33,6 @@ void CkptDataHandler::CleanTransfer() { transferData.int64Arr.clear(); transferData.int32Arr.clear(); - transferData.floatArr.clear(); transferData.attribute.clear(); transferData.datasetSize = 0; transferData.attributeSize = 0; @@ -42,7 +41,7 @@ void CkptDataHandler::CleanTransfer() void CkptDataHandler::SetDatasetForLoadEmb(CkptDataType dataType, string embName, CkptTransData& loadedData, CkptData& ckptData) { - LOG_ERROR("Load host emb failed. dataType:{}, embName:{}, loadedData:{}, ckptData:{}", - dataType, embName, loadedData.datasetSize, ckptData.embHashMaps.empty()); + LOG_ERROR("Load host emb failed. dataType:{}, embName:{}, loadedData:{}", + dataType, embName, loadedData.datasetSize); throw runtime_error("only EMB_INFO and EMB_DATA supported for load host emb"); } \ No newline at end of file diff --git a/src/core/ckpt_data_handler/ckpt_data_handler.h b/src/core/ckpt_data_handler/ckpt_data_handler.h index 383317d9751312b4d38acae09be074f8431243e1..0ca33294a2e2246d73c935e650cae7cc7811a7e7 100644 --- a/src/core/ckpt_data_handler/ckpt_data_handler.h +++ b/src/core/ckpt_data_handler/ckpt_data_handler.h @@ -18,8 +18,6 @@ See the License for the specific language governing permissions and #include -#include "emb_hashmap/emb_hashmap.h" -#include "host_emb/host_emb.h" #include "utils/common.h" namespace MxRec { diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp index be35044bd991a91f3e90496e778229289e1d4872..140b9c773afbfc4491d287438940b5670fd0c5d0 100644 --- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp +++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp @@ -157,7 +157,7 @@ void FeatAdmitNEvictCkpt::SetHistRec(string embName) for (size_t i = featItemInfoOffset; i < featItemInfoTotalSize + featItemInfoOffset; i += featItemInfoSaveNum) { process = i % printPerStep; if (process == 1) { - LOG_DEBUG("====in SetHistRec, process : %f", i / featItemInfoTotalSize); + LOG_TRACE("====in SetHistRec, process : {}", i / featItemInfoTotalSize); } auto featureId = transArr[i + featureIdIdxOffset]; auto count = transArr[i + countIdxOffset]; diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp deleted file mode 100644 index 977b2c0b5fb4dbbfa92741302b9b0f0d42db9768..0000000000000000000000000000000000000000 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ /dev/null @@ -1,477 +0,0 @@ -/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and - limitations under the License. 
-==============================================================================*/ - -#include "emb_hashmap.h" -#include -#include - -#include "hybrid_mgmt/hybrid_mgmt_block.h" -#include "utils/common.h" -#include "emb_table/embedding_mgmt.h" - -using namespace MxRec; - -void EmbHashMap::Init(const RankInfo& ri, const vector& embInfos, bool ifLoad) -{ - this->rankInfo = ri; - if (!ifLoad) { - EmbHashMapInfo embHashMapInfo; - LOG_INFO("init emb hash map from scratch"); - for (const auto& embInfo: embInfos) { - embHashMapInfo.devOffset2Batch.resize(embInfo.devVocabSize); - embHashMapInfo.devOffset2Key.resize(embInfo.devVocabSize); - embHashMapInfo.hostVocabSize = embInfo.hostVocabSize; - embHashMapInfo.devVocabSize = embInfo.devVocabSize; - embHashMapInfo.currentUpdatePos = 0; - fill(embHashMapInfo.devOffset2Batch.begin(), embHashMapInfo.devOffset2Batch.end(), -1); - fill(embHashMapInfo.devOffset2Key.begin(), embHashMapInfo.devOffset2Key.end(), -1); - embHashMaps[embInfo.name] = embHashMapInfo; - - LOG_TRACE("devOffset2Key, {}", VectorToString(embHashMaps.at(embInfo.name).devOffset2Key)); - LOG_TRACE("devOffset2Batch, {}", VectorToString(embHashMaps.at(embInfo.name).devOffset2Batch)); - } - } -} - -void EmbHashMap::ClearLookupAndSwapOffset(EmbHashMapInfo& embHashMap) const -{ - embHashMap.swapPos.clear(); - embHashMap.lookUpVec.clear(); - embHashMap.ddr2HbmKeys.clear(); -} - -/// DDR模型下处理特征的offset、swap信息等 -/// \param embName 表名 -/// \param keys 查询向量 -/// \param DDRParam 临时向量 -/// \param channelId 通道索引(训练/推理) -void EmbHashMap::Process(const string& embName, vector& keys, DDRParam& ddrParam, int channelId) -{ -#ifndef GTEST - EASY_FUNCTION(profiler::colors::Pink) - TimeCost swapTimeCost; - std::shared_ptr table = EmbeddingMgmt::Instance()->GetTable(embName); - - int32_t keepBatch = swapId; // 处理batch的次数,多个预取一起处理算一次 - vector swapPos; - vector lookUpVec = table->FindOffset(keys, swapId, channelId, swapPos); - - table->RefreshFreqInfoWithSwap(); - - EASY_BLOCK("hostHashMaps->tdt") - - std::copy(lookUpVec.begin(), lookUpVec.end(), std::back_inserter(ddrParam.offsetsOut)); - - // 构造查询向量tensor - int lookUpVecSize = static_cast(lookUpVec.size()); - ddrParam.tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { lookUpVecSize })); - - auto lookupTensorData = ddrParam.tmpDataOut.back().flat(); - for (int i = 0; i < lookUpVecSize; i++) { - lookupTensorData(i) = static_cast(lookUpVec[i]); - } - LOG_TRACE("lookupTensor, {}", VectorToString(lookUpVec)); - - // 构造交换向量tensor - int swapSize = static_cast(swapPos.size()); - ddrParam.tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { swapSize })); - - auto swapTensorData = ddrParam.tmpDataOut.back().flat(); - for (int i = 0; i < swapSize; i++) { - swapTensorData(i) = static_cast(swapPos[i]); - } - if (swapSize > 0) { - LOG_DEBUG("swap num: {}", swapSize); - } - - LOG_TRACE("swapTensor, {}", VectorToString(swapPos)); - // 清空本次记录的查询偏移和交换偏移 - table->ClearLookupAndSwapOffset(); - - LOG_INFO("current ddr emb:{}, usage:{}/[{}+{}]", embName, table->GetMaxOffset(), - table->GetDevVocabSize(), table->GetHostVocabSize()); - - ddrParam.tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 })); - auto swapLen = ddrParam.tmpDataOut.back().flat(); - swapLen(0) = swapSize; - - if (GlogConfig::gStatOn) { - LOG_INFO(STAT_INFO "channel_id {} batch_id {} rank_id {} swap_key_size {} swap_time_cost {}", - channelId, swapId, rankInfo.rankId, swapSize, swapTimeCost.ElapsedMS()); - } - - swapId++; - EASY_END_BLOCK -#endif -} - - -auto EmbHashMap::GetHashMaps() -> 
absl::flat_hash_map -{ - LOG_DEBUG(HYBRID_BLOCKING + " start GetHashMaps"); - HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); - auto embHashMapsOld = embHashMaps; - int checkResult = hybridMgmtBlock->CheckSaveEmbMapValid(); - if (checkResult == 0) { - // 检查是否需要回退 - return embHashMapsOld; - } - if (checkResult == 1) { - // 回退一步 - for (auto& temp: embHashMapsOld) { - auto &embHashMap = temp.second; - for (auto &swapKeys: embHashMap.oldSwap) { - emb_key_t oldKey = swapKeys.first; - emb_key_t key = swapKeys.second; - int tempOffset = static_cast(embHashMap.hostHashMap[key]); - embHashMap.hostHashMap[key] = embHashMap.hostHashMap[oldKey]; - embHashMap.hostHashMap[oldKey] = static_cast(tempOffset); - } - embHashMap.maxOffset = embHashMap.maxOffsetOld; - for (auto &offset2Key: embHashMap.devOffset2KeyOld) { - embHashMap.devOffset2Key[offset2Key.first] = offset2Key.second; - } - } - return embHashMapsOld; - } - // 此时需要回退2步,无法满足此条件,保存的东西错误,直接回退 - if (rankInfo.isDDR) { - throw HybridMgmtBlockingException("EmbHashMap::GetHashMaps() "); - } - return embHashMapsOld; -} - -void EmbHashMap::LoadHashMap(EmbHashMemT& loadData) -{ - embHashMaps = std::move(loadData); -} - -/// 对HBM剩余空间和更新位置进行初始化 -void EmbHashMapInfo::SetStartCount() -{ - currentUpdatePosStart = currentUpdatePos; - freeSize = devVocabSize; -} - -/// 判断HBM是否有剩余空间 -/// \param i 查询向量的大小 -/// \return -bool EmbHashMapInfo::HasFree(size_t i) const -{ - return freeSize < i; -} - -/* -* 删除淘汰key的映射关系,并将其offset更新到evictPos,待后续复用 -*/ -void EmbHashMap::EvictDeleteEmb(const string& embName, const vector& keys) -{ - EASY_FUNCTION() - size_t keySize = keys.size(); - auto& embHashMap = embHashMaps.at(embName); - vector evictHBMKeys; - vector evictDDRKeys; - for (size_t i = 0; i < keySize; i++) { - size_t offset; - auto key = keys[i]; - if (key == -1) { - LOG_WARN("evict key equal -1!"); - continue; - } - const auto& iter = embHashMap.hostHashMap.find(key); - if (iter != embHashMap.hostHashMap.end()) { - offset = iter->second; - embHashMap.hostHashMap.erase(iter); - LOG_TRACE("evict embName {}, offset {}", embName, offset); - } else { - // 淘汰依据keyProcess中的history,hashmap映射关系创建于ParseKey;两者异步,造成淘汰的值在hashmap里可能未创建 - continue; - } - - if (offset < embHashMap.devVocabSize) { - embHashMap.devOffset2Batch[offset] = -1; - embHashMap.devOffset2KeyOld.emplace_back(offset, embHashMap.devOffset2Key[offset]); - embHashMap.devOffset2Key[offset] = -1; - embHashMap.evictDevPos.emplace_back(offset); - evictHBMKeys.emplace_back(key); - } else { - embHashMap.evictPos.emplace_back(offset); - evictDDRKeys.emplace_back(key); - } - } - if (isSSDEnabled) { - cacheManager->RefreshFreqInfoCommon(embName, evictHBMKeys, TransferType::HBM_2_EVICT); - cacheManager->RefreshFreqInfoCommon(embName, evictDDRKeys, TransferType::DDR_2_EVICT); - } - - LOG_INFO("ddr EvictDeleteEmb, emb: [{}], hostEvictSize: {}, devEvictSize: {}", - embName, embHashMap.evictPos.size(), embHashMap.evictDevPos.size()); - LOG_TRACE("hostHashMap, {}", MapToString(embHashMaps[embName].hostHashMap)); -} - - -/// 从embHashMaps获取key对应的位置,构造查询向量;更新devOffset2Batch;记录dev与host需要交换的偏移 -/// \param embName 表名 -/// \param keys 查询向量 -/// \param currentBatchId 已处理的batch数 -/// \param keepBatchId 处理batch的次数,多个预取一起处理算一次 -/// \param channelId 通道索引(训练/推理) -void EmbHashMap::FindOffset(const string& embName, const vector& keys, - size_t currentBatchId, size_t keepBatchId, int channelId) -{ - EASY_FUNCTION() - size_t keySize = keys.size(); - auto it = embHashMaps.find(embName); - if (it == embHashMaps.end()) { - throw 
runtime_error("table not exist in embHashMaps"); - } - auto &embHashMap = it->second; - UpdateBatchId(keys, currentBatchId, keySize, embHashMap); - for (size_t i = 0; i < keySize; i++) { - auto key = keys[i]; - if (key == -1) { - embHashMap.lookUpVec.emplace_back(INVALID_KEY_VALUE); - continue; - } - size_t offset; - auto isOffsetValid = FindOffsetHelper(key, embHashMap, channelId, offset); - if (!isOffsetValid) { - embHashMap.lookUpVec.emplace_back(INVALID_KEY_VALUE); - continue; - } - AddKeyFreqInfo(embName, key, RecordType::NOT_DDR); - if (offset < embHashMap.devVocabSize) { - // 偏移小于等于HBM容量:直接放入查询向量;更新偏移之前关联的key和当前关联的key - embHashMap.lookUpVec.emplace_back(offset); - embHashMap.devOffset2KeyOld.emplace_back(offset, static_cast(embHashMap.devOffset2Key[offset])); - embHashMap.devOffset2Key[offset] = key; - } else { - // 偏移大于HBM容量:记录在host emb上的偏移;找到需要交换的HBM偏移 - embHashMap.missingKeysHostPos.emplace_back(offset - embHashMap.devVocabSize); - FindSwapPosOld(embName, key, offset, currentBatchId, keepBatchId); - } - } - if (currentBatchId == 0) { - LOG_INFO("max offset {}", embHashMap.maxOffset); - } - LOG_TRACE("hostHashMap, {}", MapToString(embHashMaps[embName].hostHashMap)); -} - - -/// 查找key对应的偏移;1. 已在hash map中,直接返回对应的offset;2. 开启淘汰的情况下,复用淘汰的位置;3. 没有则新分配 -/// \param key 输入特征 -/// \param embHashMap hash map实例 -/// \param channelId 通道索引(训练/推理) -/// \param offset 未初始化变量,用于记录 -/// \return -bool EmbHashMap::FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashMap, int channelId, size_t& offset) const - -{ - const auto& iter = embHashMap.hostHashMap.find(key); - if (iter != embHashMap.hostHashMap.end()) { - offset = iter->second; - LOG_TRACE("devVocabSize, {} , offset , {}", embHashMap.devVocabSize, offset); - if (isSSDEnabled && offset >= embHashMap.devVocabSize) { - embHashMap.ddr2HbmKeys.emplace_back(key); - } - } else if (embHashMap.evictDevPos.size() != 0 && channelId == TRAIN_CHANNEL_ID) { // 优先复用hbm表 - offset = embHashMap.evictDevPos.back(); - embHashMap.hostHashMap[key] = offset; - LOG_TRACE("ddr mode, dev evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", - key, offset, embHashMap.evictDevPos.size()); - embHashMap.evictDevPos.pop_back(); - } else if (embHashMap.evictPos.size() != 0 && channelId == TRAIN_CHANNEL_ID) { // hbm不足,再复用ddr表 - offset = embHashMap.evictPos.back(); - embHashMap.hostHashMap[key] = offset; - LOG_TRACE("ddr mode, host evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", - key, offset, embHashMap.evictPos.size()); - embHashMap.evictPos.pop_back(); - } else { - if (channelId == TRAIN_CHANNEL_ID) { - embHashMap.hostHashMap[key] = embHashMap.maxOffset; - offset = embHashMap.maxOffset; - embHashMap.maxOffset++; - if (embHashMap.maxOffset == embHashMap.devVocabSize) { - LOG_INFO("start using host vocab!"); - } - if (embHashMap.maxOffset > embHashMap.hostVocabSize + embHashMap.devVocabSize) { - LOG_ERROR("hostVocabSize too small! 
dev:{} host:{}", embHashMap.devVocabSize, embHashMap.hostVocabSize); - throw runtime_error("hostVocabSize too small"); - } - } else { - return false; - } - } - return true; -} - -/// 更新HBM中的key相应offset最近出现的batch步数,用于跟踪哪些offset是最近在使用的 -/// \param keys 查询向量 -/// \param currentBatchId 已处理的batch数 -/// \param keySize 查询向量长度 -/// \param embHashMap hash map实例 -void EmbHashMap::UpdateBatchId(const vector& keys, size_t currentBatchId, size_t keySize, - EmbHashMapInfo& embHashMap) const -{ - for (size_t i = 0; i < keySize; i++) { - size_t offset; - auto key = keys[i]; - if (key == -1) { - continue; - } - const auto& iter = embHashMap.hostHashMap.find(key); - if (iter != embHashMap.hostHashMap.end()) { - offset = iter->second; - - LOG_TRACE("key will be used, {} , offset , {}", key, offset); - if (offset < embHashMap.devVocabSize) { - // devOffset2Batch size equal to devVocabSize, unnecessary to check index boundary - embHashMap.devOffset2Batch[offset] = static_cast(currentBatchId); - } - } - } -} - -/// 利用devOffset2Batch上key最近使用的batchId,来选择需要淘汰的key,记录淘汰位置和device侧所需的keys -/// \param embName 表名 -/// \param key 输入特征 -/// \param hostOffset 全局偏移 -/// \param currentBatchId 已处理的batch数 -/// \param keepBatchId 处理batch的次数,多个预取一起处理算一次 -/// \return 是否找到需要交换的位置 -bool EmbHashMap::FindSwapPosOld(const string& embName, emb_key_t key, size_t hostOffset, size_t currentBatchId, - size_t keepBatchId) -{ - bool notFind = true; - auto it = embHashMaps.find(embName); - if (it == embHashMaps.end()) { - throw runtime_error("table not exist in embHashMaps"); - } - auto &embHashMap = it->second; - while (notFind) { - // 找到本次预取之前的偏移(保证所有预取batch的key都在HBM中) - if (embHashMap.currentUpdatePos >= embHashMap.devOffset2Batch.size()) { - throw runtime_error("currentUpdatePos out of range"); - } - - if (embHashMap.devOffset2Batch[embHashMap.currentUpdatePos] < static_cast(keepBatchId)) { - embHashMap.devOffset2Batch[embHashMap.currentUpdatePos] = static_cast(currentBatchId); - embHashMap.swapPos.emplace_back(embHashMap.currentUpdatePos); // 记录需要被换出的HBM偏移 - embHashMap.lookUpVec.emplace_back(embHashMap.currentUpdatePos); // 交换的位置就是该key查询的偏移 - embHashMap.hostHashMap[key] = embHashMap.currentUpdatePos; // 更新key对应的HBM偏移 - // 记录HBM偏移之前的key - embHashMap.devOffset2KeyOld.emplace_back(embHashMap.currentUpdatePos, - embHashMap.devOffset2Key[embHashMap.currentUpdatePos]); - auto& oldKey = embHashMap.devOffset2Key[embHashMap.currentUpdatePos]; - embHashMap.oldSwap.emplace_back(oldKey, key); // 记录交换的两个key oldKey:HBM->DDR key:DDR->HBM - embHashMap.hostHashMap[oldKey] = hostOffset; // 更新被替换的key的偏移 - oldKey = key; - notFind = false; - } - embHashMap.currentUpdatePos++; // 查找位置+1 - embHashMap.freeSize--; // HBM可用空间-1 - - // 遍历完一遍整个HBM表后,从头开始遍历 - if (embHashMap.currentUpdatePos == embHashMap.devVocabSize) { - embHashMap.currentUpdatePos = 0; - } - - // 已经找完整个HBM空间,且没找到可用位置,表示HBM空间不足以放下整个batch(预取batch数)的key,无法正常执行训练,故运行时错误退出 - if (embHashMap.currentUpdatePos == embHashMap.currentUpdatePosStart && notFind) { - LOG_ERROR("devVocabSize is too small"); - throw runtime_error("devVocabSize is too small"); - } - } - return true; -} - -/// HBM-DDR换入换出时刷新频次信息 -/// \param embName emb表名 -/// \param embHashMap emb hash map -void EmbHashMap::RefreshFreqInfoWithSwap(const string& embName, EmbHashMapInfo& embHashMap) const -{ - if (!isSSDEnabled) { - return; - } - // 换入换出key列表,元素为pair: pair oldKey为从HBM移出的key, key为从DDR移出的key - auto& oldSwap = embHashMap.oldSwap; - LOG_DEBUG("RefreshFreqInfoWithSwap:oldSwap Size:{}", oldSwap.size()); - vector enterDDRKeys; - for (auto 
keyPair : oldSwap) { - enterDDRKeys.emplace_back(keyPair.first); - } - cacheManager->RefreshFreqInfoCommon(embName, enterDDRKeys, TransferType::HBM_2_DDR); - cacheManager->RefreshFreqInfoCommon(embName, embHashMap.ddr2HbmKeys, TransferType::DDR_2_HBM); - - AddCacheManagerTraceLog(embName, embHashMap); -} - -/// 记录日志:HBM和DDR换入换出后,比较hostHashMap中DDR内key和表对应的lfuCache对象中的key内容 -void EmbHashMap::AddCacheManagerTraceLog(const string& embTableName, const EmbHashMapInfo& embHashMap) const -{ - if (Logger::GetLevel() != Logger::TRACE) { - return; - } - auto& hostMap = embHashMap.hostHashMap; - auto& devSize = embHashMap.devVocabSize; - auto iter = cacheManager->ddrKeyFreqMap.find(embTableName); - if (iter == cacheManager->ddrKeyFreqMap.end()) { - throw runtime_error("table not in ddrKeyFreqMap"); - } - auto &lfu = iter->second; - const auto& lfuTab = lfu.GetFreqTable(); - if (lfuTab.empty()) { - return; - } - size_t tableKeyInDdr = 0; - vector ddrKeys; // 获取hostHashMap中保存在DDR的key - for (const auto& item : hostMap) { - if (item.second < devSize) { - continue; - } - ddrKeys.emplace_back(item.first); - ++tableKeyInDdr; - } - vector lfuKeys; - for (const auto& it : lfuTab) { - lfuKeys.emplace_back(it.first); - } - std::sort(ddrKeys.begin(), ddrKeys.end()); - std::sort(lfuKeys.begin(), lfuKeys.end()); - std::string ddrKeysString = VectorToString(ddrKeys); - std::string lfuKeysString = VectorToString(lfuKeys); - if (ddrKeysString != lfuKeysString) { - LOG_ERROR("swap HBM with DDR step error, key string not equal, ddrKeysString:{}, lfuKeysString:{}", - ddrKeysString, lfuKeysString); - } else { - LOG_INFO("swap HBM with DDR step OK, table:{}, ddrKeysString == lfuKeysString, string length:{}", - embTableName, lfuKeysString.length()); - } - - LOG_INFO("swap HBM with DDR step end, table:{}, tableKeyInDdr:{}, tableKeyInLfu:{}", - embTableName, tableKeyInDdr, lfu.keyTable.size()); -} - -/// 记录key频次数据 -/// \param embTableName emb表名 -/// \param key key -/// \param type 记录类型枚举 -void EmbHashMap::AddKeyFreqInfo(const string& embTableName, const emb_key_t& key, RecordType type) const -{ - if (!isSSDEnabled) { - return; - } - cacheManager->PutKey(embTableName, key, type); -} diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h deleted file mode 100644 index 96a75e54e2ce8d24d0faef5598d352d893d7f156..0000000000000000000000000000000000000000 --- a/src/core/emb_hashmap/emb_hashmap.h +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and - limitations under the License. 
-==============================================================================*/ - -#ifndef MX_REC_EMB_HASHMAP_H -#define MX_REC_EMB_HASHMAP_H - -#include -#include -#include -#include "absl/container/flat_hash_map.h" -#include "host_emb/host_emb.h" -#include "ssd_cache/cache_manager.h" -#include "utils/common.h" -#include "utils/time_cost.h" - -namespace MxRec { - using namespace std; - - class EmbHashMap { - public: - EmbHashMap() = default; - - void Init(const RankInfo& ri, const vector& embInfos, bool ifLoad = false); - - void Process(const string& embName, std::vector& keys, DDRParam& ddrParam, int channelId); - - auto GetHashMaps() -> absl::flat_hash_map; - - void LoadHashMap(absl::flat_hash_map& loadData); - - void EvictDeleteEmb(const string& embName, const vector& keys); - - absl::flat_hash_map embHashMaps; - - bool FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashMap, int channelId, size_t& offset) const; - - void UpdateBatchId(const vector& keys, size_t currentBatchId, size_t keySize, - EmbHashMapInfo& embHashMap) const; - - bool FindSwapPosOld(const string& embName, emb_key_t key, size_t hostOffset, size_t currentBatchId, - size_t keepBatchId); - - std::vector& GetEvictPos(const string& embName) - { - return embHashMaps.at(embName).evictPos; - } - - bool isSSDEnabled { false }; - CacheManager* cacheManager; - - GTEST_PRIVATE: - - void FindOffset(const string& embName, const vector& keys, - size_t currentBatchId, size_t keepBatchId, int channelId); - - void AddCacheManagerTraceLog(const string& embTableName, const EmbHashMapInfo& embHashMap) const; - - void AddKeyFreqInfo(const string& embTableName, const emb_key_t& key, RecordType type) const; - - void ClearLookupAndSwapOffset(EmbHashMapInfo& embHashMap) const; - - void RefreshFreqInfoWithSwap(const string& embName, EmbHashMapInfo& embHashMap) const; - - RankInfo rankInfo; - int swapId { 0 }; - }; -} - -#endif // MX_REC_EMB_HASHMAP_H diff --git a/src/core/emb_table/emb_table.cpp b/src/core/emb_table/emb_table.cpp deleted file mode 100644 index 1c24eb2b1f215362b4a670d4ea4e6a4104924bd8..0000000000000000000000000000000000000000 --- a/src/core/emb_table/emb_table.cpp +++ /dev/null @@ -1,161 +0,0 @@ -/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and - limitations under the License. 
-==============================================================================*/ - -#include -#include -#include -#include -#include "acl/acl_base.h" -#include "utils/common.h" -#include "initializer/initializer.h" -#include "emb_table/emb_table.h" - - -using namespace std; -using namespace MxRec; -using namespace tensorflow; - -void EmbTable::Init(const EmbInfo& eInfo, const RankInfo& rInfo, int initSeed) -{ -#ifndef GTEST - this->rankInfo = rInfo; - this->seed = initSeed; - this->embInfo = eInfo; - LOG_INFO("EmbTable init, deviceID {}, embSize {} running", rInfo.deviceId, embInfo.extEmbeddingSize); - // 计算embedding table需要分配的内存块数 - auto ret = aclrtSetDevice(static_cast(rInfo.deviceId)); - if (ret != ACL_ERROR_NONE) { - LOG_ERROR("Set device failed, device_id:{}, ret={}", rInfo.deviceId, ret); - throw AclError(); - } - embSize = embInfo.extEmbeddingSize; - blockSize = BLOCK_EMB_COUNT * embSize; - for (int i = 0; i < INIT_BLOCK_COUNT; ++i) { - // 申请新的内存块 - void *newBlock = nullptr; - aclError ec = aclrtMalloc(&newBlock, blockSize * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST); - if (ec != ACL_SUCCESS) { - LOG_ERROR("aclrtMalloc failed, ret={}", ec); - throw AclError(); - } - // 申请内存初始化 - RandomInit(newBlock); - // 将新的内存块加入内存链表 - memoryList.push_back(newBlock); - SplitMemoryBlock(newBlock); - } - totalCapacity = static_cast(memoryList.size()) * BLOCK_EMB_COUNT; - LOG_INFO("aclrtMalloc success, emb name:{}, total capacity:{}", embInfo.name, totalCapacity); -#endif -} - -EmbTable::~EmbTable() -{ -#ifndef GTEST - for (void *block : memoryList) { - // 释放内存块 - aclError ret = aclrtFree(block); - if (ret != ACL_SUCCESS) { - LOG_ERROR("aclrtFree failed, ret={}", ret); - } - block = nullptr; - } -#endif -} - -// 从embeddingList获取一个可用的emb地址 -int64_t EmbTable::GetEmbAddress() -{ -#ifndef GTEST - if (embeddingList.empty()) { - PrintStatus(); - LOG_DEBUG("GetEmbAddress, embedding_list size: empty! Add block!"); - void *addBlock = nullptr; - aclError ret = aclrtMalloc(&addBlock, blockSize * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST); - if (ret != ACL_SUCCESS) { - LOG_ERROR("aclrtMalloc failed, ret={}", ret); - throw AclError(); - } - RandomInit(addBlock); - // 将新的内存块加入内存list - memoryList.push_back(addBlock); - SplitMemoryBlock(addBlock); - totalCapacity += BLOCK_EMB_COUNT; - } - float *embAddr = embeddingList.front(); - embeddingList.pop_front(); - usedCapacity++; - return reinterpret_cast(embAddr); -#endif -} - -void EmbTable::RandomInit(void* newBlock) -{ -#ifndef GTEST - LOG_INFO("Device GenerateEmbData Start, seed:{}, initializer num: {}", seed, embInfo.initializeInfos.size()); - vector devEmb(blockSize); - for (const auto& initializeInfo: as_const(embInfo.initializeInfos)) { - LOG_INFO("Device GenerateEmbData ing. 
name {}", initializeInfo.name.c_str()); - for (int i = 0; i < BLOCK_EMB_COUNT; i++) { - initializeInfo.initializer->GenerateData(&devEmb[i * embSize], embSize); - } - } - LOG_INFO("Device GenerateEmbData End, seed:{}", seed); - ExecuteAclMemcpy(newBlock, devEmb); -#endif -} - -void EmbTable::ExecuteAclMemcpy(void* newBlock, vector devEmb) const -{ -#ifndef GTEST - aclError ret = aclrtMemcpy( - newBlock, blockSize * sizeof(float), devEmb.data(), blockSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE); - if (ret != ACL_SUCCESS) { - LOG_ERROR("aclrtMemcpy failed, ret={}", ret); - throw AclError(); - } -#endif -} - - -void EmbTable::SplitMemoryBlock(void *newBlock) -{ -#ifndef GTEST - if (embSize == 0) { - throw std::runtime_error("SplitMemoryBlock by embSize=0!"); - } - for (int i = 0; i < BLOCK_EMB_COUNT; i++) { - float *embPtr = static_cast(newBlock) + i * embSize; - embeddingList.push_back(embPtr); - } -#endif -} - -void EmbTable::PrintStatus() const -{ - // 输出embedding table的总容量和未使用的使用容量 - LOG_INFO("Total capacity:{}, Unused capacity:{}", - totalCapacity * embSize, totalCapacity * embSize - usedCapacity * embSize); -} - -int64_t EmbTable::GetTableSize() const -{ - return static_cast(usedCapacity); -} - -int64_t EmbTable::GetTableCapacity() const -{ - return static_cast(totalCapacity); -} diff --git a/src/core/emb_table/embedding_ddr.cpp b/src/core/emb_table/embedding_ddr.cpp index 02d7c1164eb988a36b142e47ebe350faaf396419..d05b35019f9b9af0e019616dba27f1d0eccaa8de 100644 --- a/src/core/emb_table/embedding_ddr.cpp +++ b/src/core/emb_table/embedding_ddr.cpp @@ -12,46 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #include "emb_table/embedding_ddr.h" + #include + #include "utils/logger.h" #include "utils/singleton.h" -#include "host_emb/host_emb.h" -#include "file_system/file_system_handler.h" -#include "ssd_cache/cache_manager.h" -#include "emb_table/embedding_mgmt.h" +#include "l3_storage/cache_manager.h" +#include "ock_ctr_common/include/error_code.h" using namespace MxRec; -constexpr int ELEMENT_NUM = 4; -constexpr int CURRENT_UPDATE_IDX = 0; -constexpr int HOST_VOCAB_SIZE_IDX = 1; -constexpr int DEV_VOCAB_SIZE_IDX = 2; -constexpr int MAX_OFFSET_IDX = 3; - -constexpr int EMB_INFO_ELEMENT_NUM = 3; -constexpr int EMB_INFO_EXT_SIZE_IDX = 0; -constexpr int EMB_INFO_DEV_VOCAB_SIZE_IDX = 1; -constexpr int EMB_INFO_HOST_VOCAB_SIZE_IDX = 2; - EmbeddingDDR::EmbeddingDDR() { } EmbeddingDDR::EmbeddingDDR(const EmbInfo& info, const RankInfo& rankInfo, int inSeed) - : EmbeddingTable(info, rankInfo, inSeed) + : EmbeddingTable(info, rankInfo, inSeed), deviceId(rankInfo.deviceId) { - LOG_INFO("Init DDR table [{}] devVocabSize = {} hostVocabSize = {}", name, devVocabSize, hostVocabSize); - currentUpdatePos = 0; - devOffset2Key.resize(devVocabSize); - devOffset2Batch.resize(devVocabSize); - std::fill(devOffset2Batch.begin(), devOffset2Batch.end(), -1); - std::fill(devOffset2Key.begin(), devOffset2Key.end(), -1); + LOG_INFO("Init DDR table:{}, devVocabSize:{}, hostVocabSize:{}", name, devVocabSize, hostVocabSize); } EmbeddingDDR::~EmbeddingDDR() { + hdTransfer = nullptr; + embCache = nullptr; } void EmbeddingDDR::Key2Offset(std::vector& splitKey, int channel) @@ -60,169 +45,7 @@ void EmbeddingDDR::Key2Offset(std::vector& splitKey, int channel) int64_t EmbeddingDDR::capacity() const { - return 
capacity_; -} - -std::vector EmbeddingDDR::FindOffset(const vector& keys, - size_t batchId, int channelId, - std::vector& swapPos) -{ - devOffset2KeyOld.clear(); - oldSwap.clear(); - maxOffsetOld = maxOffset; - - UpdateBatchId(keys, batchId); - std::vector lookUpVec; - for (size_t i = 0; i < keys.size(); i++) { - emb_key_t key = keys[i]; - if (key == INVALID_KEY_VALUE) { - lookUpVec.emplace_back(INVALID_KEY_VALUE); - continue; - } - emb_key_t offset = FindOffsetHelper(key, channelId); - if (offset == INVALID_KEY_VALUE) { - lookUpVec.emplace_back(INVALID_KEY_VALUE); - continue; - } - AddKeyFreqInfo(key, RecordType::NOT_DDR); - if (offset < devVocabSize) { - // 偏移小于等于HBM容量:直接放入查询向量;更新偏移之前关联的key和当前关联的key - lookUpVec.push_back(offset); - devOffset2KeyOld.emplace_back(offset, static_cast(devOffset2Key[offset])); - devOffset2Key[offset] = key; - } else { - // 偏移大于HBM容量:记录在host emb上的偏移;找到需要交换的HBM偏移 - missingKeysHostPos_.emplace_back(offset - devVocabSize); - offset = FindSwapPosOld(key, offset, batchId, swapPos); - lookUpVec.emplace_back(offset); - } - } - if (batchId == 0) { - LOG_INFO("max offset {}", maxOffset); - } - LOG_TRACE("keyOffsetMap, {}", MapToString(keyOffsetMap)); - return lookUpVec; -} - -emb_key_t EmbeddingDDR::FindOffsetHelper(const emb_key_t& key, int channelId) -{ - const auto& iter = keyOffsetMap.find(key); - emb_key_t offset = INVALID_KEY_VALUE; - if (iter != keyOffsetMap.end()) { - offset = iter->second; - LOG_TRACE("devVocabSize, {} , offset , {}", devVocabSize, offset); - if (isSSDEnabled_ && offset >= devVocabSize) { - ddr2HbmKeys.emplace_back(key); - } - return offset; - } - if (channelId != TRAIN_CHANNEL_ID) { - return offset; - } - if (evictDevPos.size() != 0) { // 优先复用hbm表 - offset = evictDevPos.back(); - keyOffsetMap[key] = offset; - LOG_TRACE("ddr mode, dev evictDevPos is not null, key [{}] reuse offset [{}], evictSize [{}]", - key, offset, evictDevPos.size()); - evictDevPos.pop_back(); - LOG_ERROR("dev evicted offset = {}", offset); - return offset; - } - - if (evictHostPos.size() != 0) { // hbm不足,再复用host/ddr表 - offset = evictHostPos.back(); - keyOffsetMap[key] = offset; - LOG_TRACE("ddr mode, host evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", - key, offset, evictHostPos.size()); - evictHostPos.pop_back(); - LOG_TRACE("host evicted offset = {}", offset); - return offset; - } - keyOffsetMap[key] = maxOffset; - offset = maxOffset; - maxOffset++; - if (maxOffset == devVocabSize) { - LOG_INFO("start using host vocab!"); - } - if (maxOffset > (hostVocabSize + devVocabSize)) { - LOG_ERROR("hostVocabSize too small! 
dev:{} host:{}", devVocabSize, hostVocabSize); - throw runtime_error("hostVocabSize too small"); - } - return offset; -} - -void EmbeddingDDR::UpdateBatchId(const vector& keys, size_t currentBatchId) -{ - for (size_t i = 0; i < keys.size(); i++) { - size_t offset; - emb_key_t key = keys[i]; - if (key == -1) { - continue; - } - const auto& iter = keyOffsetMap.find(key); - if (iter != keyOffsetMap.end()) { - offset = iter->second; - - LOG_TRACE("key will be used, {} , offset , {}", key, offset); - if (offset < devVocabSize) { - // devOffset2Batch size equal to devVocabSize, unnecessary to check index boundary - devOffset2Batch[offset] = static_cast(currentBatchId); - } - } - } -} - -/// 利用devOffset2Batch上key最近使用的batchId,来选择需要淘汰的key,记录淘汰位置和device侧所需的keys -/// \param embName 表名 -/// \param key 输入特征 -/// \param hostOffset 全局偏移 -/// \param currentBatchId 已处理的batch数 -/// \param keepBatchId 处理batch的次数,多个预取一起处理算一次 -/// \return 是否找到需要交换的位置 -emb_key_t EmbeddingDDR::FindSwapPosOld(emb_key_t key, size_t hostOffset, size_t batchId, - std::vector& swapPos) -{ - bool notFind = true; - emb_key_t offset = INVALID_KEY_VALUE; - while (notFind) { - // 找到本次预取之前的偏移(保证所有预取batch的key都在HBM中) - if (currentUpdatePos >= devOffset2Batch.size()) { - LOG_ERROR("outofrange {} >= {}", currentUpdatePos, devOffset2Batch.size()); - throw runtime_error("currentUpdatePos out of range"); - } - - if (devOffset2Batch[currentUpdatePos] < static_cast(batchId)) { - devOffset2Batch[currentUpdatePos] = static_cast(batchId); - swapPos.emplace_back(currentUpdatePos); // 记录需要被换出的HBM偏移 - offset = currentUpdatePos; - keyOffsetMap[key] = currentUpdatePos; // 更新key对应的HBM偏移 - // 记录HBM偏移之前的key - devOffset2KeyOld.emplace_back(currentUpdatePos, devOffset2Key[currentUpdatePos]); - auto& oldKey = devOffset2Key[currentUpdatePos]; - oldSwap.emplace_back(oldKey, key); // 记录交换的两个key oldKey:HBM->DDR key:DDR->HBM - keyOffsetMap[oldKey] = hostOffset; // 更新被替换的key的偏移 - oldKey = key; - notFind = false; - } - currentUpdatePos++; // 查找位置+1 - freeSize_--; // HBM可用空间-1 - - // 遍历完一遍整个HBM表后,从头开始遍历 - if (currentUpdatePos == devVocabSize) { - currentUpdatePos = 0; - } - - /** - * currentUpdatePos已经绕了HBM一圈 - * 已经找完整个HBM空间,且没找到可用位置,表示HBM空间不足以放下整个batch(预取batch数)的key, - * 无法正常执行训练,故运行时错误退出 - */ - if (currentUpdatePos == currentUpdatePosStart && notFind) { - LOG_ERROR("devVocabSize is too small"); - throw runtime_error("devVocabSize is too small"); - } - } - return offset; + return capacity_.load(); } /* @@ -230,47 +53,6 @@ emb_key_t EmbeddingDDR::FindSwapPosOld(emb_key_t key, size_t hostOffset, size_t */ void EmbeddingDDR::EvictDeleteEmb(const vector& keys) { - EASY_FUNCTION() - size_t keySize = keys.size(); - vector evictHBMKeys; - vector evictDDRKeys; - for (size_t i = 0; i < keySize; ++i) { - size_t offset; - emb_key_t key = keys[i]; - if (key == INVALID_KEY_VALUE) { - LOG_WARN("evict key equal -1!"); - continue; - } - const auto& iter = keyOffsetMap.find(key); - if (iter == keyOffsetMap.end()) { - // 淘汰依据keyProcess中的history,hashmap映射关系创建于ParseKey;两者异步,造成淘汰的值在hashmap里可能未创建 - continue; - } - offset = iter->second; - keyOffsetMap.erase(iter); - LOG_TRACE("evict embName {}, offset {}", name, offset); - - if (offset < devVocabSize) { - // offset 在device中 - devOffset2Batch[offset] = -1; - devOffset2KeyOld.emplace_back(offset, devOffset2Key[offset]); - devOffset2Key[offset] = -1; - evictDevPos.emplace_back(offset); - evictHBMKeys.emplace_back(key); - } else { - // offset 在Host - evictHostPos.emplace_back(offset); - evictDDRKeys.emplace_back(key); // 
remove the key mapping, re-initialize the host table, and send the device evict positions - } - } - if (isSSDEnabled_) { - cacheManager_->RefreshFreqInfoCommon(name, evictHBMKeys, TransferType::HBM_2_EVICT); - cacheManager_->RefreshFreqInfoCommon(name, evictDDRKeys, TransferType::DDR_2_EVICT); - } - - LOG_INFO("ddr EvictDeleteEmb, emb: [{}], hostEvictSize: {}, devEvictSize: {}", - name, evictHostPos.size(), evictDevPos.size()); - LOG_TRACE("keyOffsetMap, {}", MapToString(keyOffsetMap)); } /// Eviction in DDR mode: remove the key mapping, re-initialize the host table, and send the device evict positions @@ -278,327 +60,304 @@ void EmbeddingDDR::EvictDeleteEmb(const vector<emb_key_t>& keys) void EmbeddingDDR::EvictKeys(const vector<emb_key_t>& keys) { - EASY_FUNCTION() - for (const emb_key_t& key : keys) { - size_t offset; - if (key == INVALID_KEY_VALUE) { - LOG_WARN("evict key equal -1!"); - continue; - } - const auto& iter = keyOffsetMap.find(key); - if (iter == keyOffsetMap.end()) { - continue; - } - // Eviction is driven by the history kept in keyProcess, while the hashmap entries are created in ParseKey; the two are asynchronous, so an evicted key may not exist in the hashmap yet - offset = iter->second; - keyOffsetMap.erase(iter); - LOG_TRACE("evict embName {}, offset {}", name, offset); - - if (offset < devVocabSize) { - devOffset2Batch[offset] = INVALID_KEY_VALUE; - devOffset2KeyOld.emplace_back(offset, devOffset2Key[offset]); - devOffset2Key[offset] = INVALID_KEY_VALUE; - evictDevPos.emplace_back(offset); - } else { - evictHostPos.emplace_back(offset); - } - } } -void EmbeddingDDR::ClearLookupAndSwapOffset() +void EmbeddingDDR::Load(const string& savePath, map<string, set<uint64_t>>& trainKeySet) { - ddr2HbmKeys.clear(); -} + vector<uint64_t> keys; + vector<vector<float>> embeddings; + vector<vector<float>> optimizerSlots; -void EmbeddingDDR::SetStartCount() -{ - currentUpdatePosStart = currentUpdatePos; - freeSize_ = devVocabSize; -} + LoadKey(savePath, keys); + LoadEmbedding(savePath, embeddings); + LoadOptimizerSlot(savePath, optimizerSlots); -void EmbeddingDDR::Load(const string& savePath) -{ - int res = LoadHashMap(savePath); - if (res == -1) { - throw std::runtime_error("load key failed!"); + auto rc = embCache->LoadEmbTableInfos(name, keys, embeddings, optimizerSlots); + if (rc != 0) { + throw runtime_error("embCache->LoadEmbTableInfos failed, err code:" + to_string(rc)); } - LoadEmbAndOptim(savePath); -} -void EmbeddingDDR::Save(const string& savePath) -{ - SaveKey(savePath); - SaveEmbAndOptim(savePath); + trainKeySet[name].insert(keys.cbegin(), keys.cend()); + // Reset the offsetMapper object to revert to its initialized state after loading + auto rs = embCache->ResetOffsetMappers(); + if (rs != 0) { + throw runtime_error("embCache->ResetOffsetMappers failed, err code: " + to_string(rs)); + } }
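LoadKey (below) keeps only the keys that belong to this rank, selecting by key value modulo the world size. A standalone sketch of that sharding rule, with hypothetical names (`ShardKeys`, `rankSize`, `rankId`); it assumes non-negative keys, since C++ `%` yields negative results for negative operands.

```cpp
// Sketch of the per-rank key filter, same rule as `buf[i] % rankSize_ != rankId_`.
#include <cstdint>
#include <vector>

std::vector<int64_t> ShardKeys(const std::vector<int64_t>& allKeys, int rankSize, int rankId)
{
    std::vector<int64_t> mine;
    for (int64_t key : allKeys) {
        // key value modulo world size decides which rank owns the key
        if (key % rankSize == rankId) {
            mine.push_back(key);
        }
    }
    return mine;
}
```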
-int EmbeddingDDR::LoadHashMap(const string& savePath) +void EmbeddingDDR::LoadKey(const string &savePath, vector<uint64_t> &keys) { stringstream ss; ss << savePath << "/" << name << "/key/slice.data"; - unique_ptr<FileSystemHandler> fileSystemHandler = make_unique<FileSystemHandler>(); - unique_ptr<FileSystem> fileSystemPtr = fileSystemHandler->Create(ss.str()); + if (fileSystemPtr_ == nullptr) { + throw runtime_error("failed to obtain the file system pointer, the file system pointer is null."); + } size_t fileSize = 0; try { - fileSize = fileSystemPtr->GetFileSize(ss.str()); + fileSize = fileSystemPtr_->GetFileSize(ss.str()); } catch (exception& e) { - LOG_ERROR("open file {} failed:{}", ss.str(), strerror(errno)); - return -1; + string errMsg = StringFormat("open file failed:%s, error:%s", ss.str().c_str(), strerror(errno)); + throw runtime_error(errMsg); } if (fileSize >= FILE_MAX_SIZE) { - LOG_ERROR("file {} size = {} is too big", ss.str(), fileSize); - return -1; + string errMsg = StringFormat("file:%s, size:%d is too big", ss.str().c_str(), fileSize); + throw runtime_error(errMsg); } - int64_t* buf = static_cast<int64_t*>(malloc(fileSize)); + // Stay compatible with the HBM format for now by converting to int64_t; the key type will be unified to uint64_t later + auto buf = static_cast<int64_t*>(malloc(fileSize)); if (buf == nullptr) { - LOG_ERROR("malloc failed: {}", strerror(errno)); - return -1; + string errMsg = StringFormat("malloc buffer failed, error:%s", strerror(errno)); + throw runtime_error(errMsg); + } + ssize_t result = fileSystemPtr_->Read(ss.str(), reinterpret_cast<char*>(buf), fileSize); + if (result == -1) { + free(static_cast<void*>(buf)); + string errMsg = StringFormat("read buffer failed, error:%s", strerror(errno)); + throw runtime_error(errMsg); + } + if (result != fileSize) { + free(static_cast<void*>(buf)); + throw runtime_error(StringFormat("Error: Load keys failed. Expected to read %d bytes, " + "but actually read %d bytes from file %s.", fileSize, result, ss.str().c_str())); } - fileSystemPtr->Read(ss.str(), reinterpret_cast<char*>(buf), fileSize); - - size_t loadKeySize = fileSize / sizeof(int64_t); - // keys are loaded to the device first - loadOffset.clear(); hostLoadOffset.clear(); - int keyCount = 0; - for (int i = 0; i < loadKeySize; i = i + 1) { + size_t loadKeySize = fileSize / sizeof(int64_t); + for (size_t i = 0; i < loadKeySize; i++) { + // shard keys across the ranks if (buf[i] % rankSize_ != rankId_) { continue; } - if (keyCount > devVocabSize + hostVocabSize) { - LOG_ERROR("load key size exceeds the sum of device vocab size and host vocab size: {}", strerror(errno)); - return -1; - } else if (keyCount < devVocabSize) { - loadOffset.push_back(i); - devOffset2Key[keyCount] = buf[i]; - } else { - hostLoadOffset.push_back(i); - } - keyOffsetMap[buf[i]] = keyCount; - keyCount++; + hostLoadOffset.emplace_back(i); + keys.emplace_back(static_cast<uint64_t>(buf[i])); } - maxOffset = keyOffsetMap.size(); free(static_cast<void*>(buf)); - return 0; + LOG_DEBUG("load key done, table:{}", name); } -void EmbeddingDDR::LoadEmbAndOptim(const string& savePath) +void EmbeddingDDR::LoadEmbedding(const string &savePath, vector<vector<float>> &embeddings) { + // must init first + for (size_t i = 0; i < hostLoadOffset.size(); i++) { + vector<float> tmp(embSize_); + embeddings.emplace_back(tmp); + } + stringstream ss; ss << savePath << "/" << name; + stringstream embedStream; + embedStream << ss.str() << "/" << "embedding/slice.data"; - unique_ptr<FileSystemHandler> fileSystemHandler = make_unique<FileSystemHandler>(); - unique_ptr<FileSystem> fileSystemPtr = fileSystemHandler->Create(ss.str()); + if (fileSystemPtr_ == nullptr) { + throw runtime_error("failed to obtain the file system pointer, the file system pointer is null."); + } + ssize_t res = fileSystemPtr_->Read(embedStream.str(), embeddings, 0, hostLoadOffset, embSize_); + LOG_DEBUG("load embedding done, table:{}, read bytes:{}", name, res); +}
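The `fileSystemPtr_->Read(path, rows, slot, offsets, embSize)` overload used by LoadEmbedding is not part of this diff; the sketch below is a plausible equivalent under our own assumptions: rows are fixed-width float records, `recordOffsets` holds the record indices collected by LoadKey, and the caller has already pre-sized `rows` (the "must init first" loop above).

```cpp
// Sketch of an offset-based row read, assuming contiguous fixed-width float
// records; the repo's actual FileSystem overload is not shown in this diff.
#include <sys/types.h>

#include <cstdint>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>

ssize_t ReadRows(const std::string& path, std::vector<std::vector<float>>& rows,
                 const std::vector<int64_t>& recordOffsets, size_t embSize)
{
    std::ifstream in(path, std::ios::binary);
    if (!in) {
        throw std::runtime_error("open failed: " + path);
    }
    const std::streamsize rowBytes = static_cast<std::streamsize>(embSize * sizeof(float));
    ssize_t total = 0;
    // caller guarantees rows.size() >= recordOffsets.size() and each row has embSize floats
    for (size_t i = 0; i < recordOffsets.size(); ++i) {
        // record index -> byte position of that fixed-width row in the slice file
        in.seekg(recordOffsets[i] * rowBytes);
        in.read(reinterpret_cast<char*>(rows[i].data()), rowBytes);
        if (in.gcount() != rowBytes) {
            throw std::runtime_error("short read in " + path);
        }
        total += in.gcount();
    }
    return total;
}
```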
- HostEmb *hostEmbs = Singleton<HostEmb>::GetInstance(); - HostEmbTable &table = hostEmbs->GetEmb(name); - if (table.embData.empty()) { - LOG_ERROR("hostEmb data is empty"); +void EmbeddingDDR::LoadOptimizerSlot(const string &savePath, vector<vector<float>> &optimizerSlots) +{ + if (optimParams.size() == 0) { + LOG_DEBUG("optimizer has no slot data to load"); return; } - // read the embedding - stringstream embedStream; - embedStream << ss.str() << "/" << "embedding/slice.data"; - ssize_t res = fileSystemPtr->Read(embedStream.str(), table.embData, 0, hostLoadOffset, embSize_); + // must init first + for (size_t i = 0; i < hostLoadOffset.size(); i++) { + vector<float> tmp(extEmbSize_ - embSize_); + optimizerSlots.emplace_back(tmp); + } + + stringstream ss; + ss << savePath << "/" << name; - // read the optimizer state - int64_t optimIndex = 1; + if (fileSystemPtr_ == nullptr) { + throw runtime_error("failed to obtain the file system pointer, the file system pointer is null."); + } + int64_t slotIdx = 0; for (const auto &param: optimParams) { stringstream paramStream; paramStream << ss.str() << "/" << optimName + "_" + param << "/slice.data"; - ssize_t res = fileSystemPtr->Read(paramStream.str(), table.embData, optimIndex, hostLoadOffset, embSize_); - optimIndex ++; + ssize_t res = fileSystemPtr_->Read(paramStream.str(), optimizerSlots, slotIdx, hostLoadOffset, embSize_); + slotIdx++; + LOG_DEBUG("load optimizer slot, table:{}, slot:{}, read bytes:{}", name, param, res); } + + LOG_DEBUG("load optimizer slot done, table:{}", name); } +void EmbeddingDDR::Save(const string& savePath) +{ + SyncLatestEmbedding(); + vector<uint64_t> keys; + vector<vector<float>> embeddings; + vector<vector<float>> optimizerSlots; + + auto step = GetStepFromPath(savePath); + embCache->GetEmbTableInfos(name, keys, embeddings, optimizerSlots); + + SaveKey(savePath, keys); + SaveEmbedding(savePath, embeddings); + SaveOptimizerSlot(savePath, optimizerSlots, keys.size()); +} + +void EmbeddingDDR::SyncLatestEmbedding() +{ + // export the embeddings that the host records as resident on the NPU + std::vector<std::pair<uint64_t, uint64_t>> koVec; + int rc = embCache->ExportDeviceKeyOffsetPairs(name, koVec); + if (rc != ock::ctr::H_OK) { + string errMsg = StringFormat("ExportDeviceKeyOffsetPairs failed, table:%s, error code:%d", name.c_str(), rc); + throw std::invalid_argument(errMsg); + } + std::vector<uint64_t> swapOutKeys; + for (const auto& p : koVec) { + swapOutKeys.push_back(p.first); + } + LOG_DEBUG("save swapOutKeys.size:{}, table:{}", swapOutKeys.size(), name); + + // receive the on-device embeddings sent by the Python save API + auto size = hdTransfer->RecvAcl(TransferChannel::SAVE_D2H, TRAIN_CHANNEL_ID, name, 0, -1); + LOG_DEBUG("save acltdtGetDatasetSize, size: {}, table:{}", size, name); + auto aclData = acltdtGetDataItem(hdTransfer->aclDatasets[name][0], 0); + if (aclData == nullptr) { + throw runtime_error("Acl get tensor data from dataset failed."); + } + auto* ptr = reinterpret_cast<float*>(acltdtGetDataAddrFromItem(aclData)); + + if (ssdVocabSize == 0) { + // refresh the host embeddings before saving + rc = embCache->EmbeddingUpdate(name, swapOutKeys, ptr); + if (rc != ock::ctr::H_OK) { + string errMsg = StringFormat("EmbeddingUpdate failed, table:%s, error code:%d", name.c_str(), rc); + throw std::invalid_argument(errMsg); + } + } else { + // refresh the DDR and SSD embeddings before saving + HBMSwapOutInfo info; + cacheManager_->ProcessSwapOutKeys(name, swapOutKeys, info); + vector<float*> swapOutAddrs; + rc = embCache->EmbeddingLookupAddrs(name, info.swapOutDDRKeys, swapOutAddrs); + if (rc != ock::ctr::H_OK) { + string errMsg = StringFormat("EmbeddingLookupAddrs failed, table:%s, error code:%d", name.c_str(), rc); + throw std::invalid_argument(errMsg); + } + uint32_t extEmbeddingSize = embInfo_.extEmbeddingSize; + uint32_t memSize = extEmbeddingSize * sizeof(float); + // DDR update +#pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) \ + shared(swapOutAddrs, info, ptr, extEmbeddingSize, memSize) + for (uint64_t i = 0; i < swapOutAddrs.size(); i++) { + int errCode = memcpy_s( + swapOutAddrs[i], memSize, ptr + info.swapOutDDRAddrOffs[i] * extEmbeddingSize, memSize); + if (errCode != 0) { + string errMsg = StringFormat("memcpy_s failed, table:%s, error code:%d", name.c_str(), errCode); + throw std::invalid_argument(errMsg); + } + } + cacheManager_->UpdateL3StorageEmb(name, ptr, embInfo_.extEmbeddingSize, info.swapOutL3StorageKeys, + info.swapOutL3StorageAddrOffs); + } +}
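The OpenMP block in SyncLatestEmbedding is a row-wise scatter: each swapped-out embedding row is copied from the flat device buffer into its destination address in host DDR. A minimal sketch with simplified types; `ScatterRows` is hypothetical, and plain `std::memcpy` with a fixed thread count stands in for the repo's `memcpy_s` and `MGMT_CPY_THREADS`.

```cpp
// Sketch of the parallel row scatter, under our own simplified types.
#include <omp.h>

#include <cstdint>
#include <cstring>
#include <vector>

void ScatterRows(const std::vector<float*>& dstAddrs,     // per-row destinations in host DDR
                 const std::vector<uint64_t>& srcRowIdx,  // row index of each key in `src`
                 const float* src, uint32_t extEmbSize)
{
    const size_t rowBytes = extEmbSize * sizeof(float);
#pragma omp parallel for num_threads(8)
    for (int64_t i = 0; i < static_cast<int64_t>(dstAddrs.size()); ++i) {
        // destination rows are disjoint, so concurrent copies do not race
        std::memcpy(dstAddrs[i], src + srcRowIdx[i] * extEmbSize, rowBytes);
    }
}
```

The copies can run in parallel because each destination row is distinct; the loop index is signed for portability with pre-3.0 OpenMP runtimes.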
-int EmbeddingDDR::SaveKey(const string& savePath) +void EmbeddingDDR::SaveKey(const string& savePath, vector<uint64_t>& keys) { stringstream ss; ss << savePath << "/" << name << "/key/"; MakeDir(ss.str()); ss << "slice_" << rankId_ << ".data"; - unique_ptr<FileSystemHandler> fileSystemHandler = make_unique<FileSystemHandler>(); - unique_ptr<FileSystem> fileSystemPtr = fileSystemHandler->Create(ss.str()); - - hostKey.clear(); - hostOffset.clear(); - deviceKey.clear(); - deviceOffset.clear(); - - for (const auto& it: keyOffsetMap) { - if (it.second >= devVocabSize) { - hostKey.push_back(it.first); - hostOffset.push_back(it.second); - } else { - deviceKey.push_back(it.first); - deviceOffset.push_back(it.second); - } - } + // Stay compatible with the HBM format for now by converting to int64_t; the key type will be unified to uint64_t later + vector<int64_t> keysCompat(keys.cbegin(), keys.cend()); - ssize_t res = fileSystemPtr->Write(ss.str(), reinterpret_cast<char*>(hostKey.data()), - static_cast<size_t>(hostKey.size() * sizeof(int64_t))); - if (res == -1) { - return -1; + if (fileSystemPtr_ == nullptr) { + throw runtime_error("failed to obtain the file system pointer, the file system pointer is null."); } - ssize_t res2 = fileSystemPtr->Write( - ss.str(), reinterpret_cast<char*>(deviceKey.data()), - static_cast<size_t>(deviceKey.size() * sizeof(int64_t)) - ); - if (res2 == -1) { - return -1; + ssize_t res = fileSystemPtr_->Write(ss.str(), reinterpret_cast<char*>(keysCompat.data()), + static_cast<size_t>(keys.size() * sizeof(int64_t))); + if (res == -1) { + throw runtime_error("save key failed!"); } - return 0; } -void EmbeddingDDR::SaveEmbData(const string& savePath) +void EmbeddingDDR::SaveEmbedding(const string& savePath, vector<vector<float>>& embeddings) { stringstream ss; ss << savePath << "/" << name << "/embedding/"; MakeDir(ss.str()); ss << "slice_" << rankId_ << ".data"; - unique_ptr<FileSystemHandler> fileSystemHandler = make_unique<FileSystemHandler>(); - unique_ptr<FileSystem> fileSystemPtr = fileSystemHandler->Create(ss.str()); - vector<char> attribute; - fileSystemPtr->Write(ss.str(), embContent, embSize_ * sizeof(float)); + if (fileSystemPtr_ == nullptr) { + throw runtime_error("failed to obtain the file system pointer, the file system pointer is null."); + } + ssize_t writeBytesNum = fileSystemPtr_->Write(ss.str(), embeddings, embSize_); + ssize_t expectWriteBytes = embeddings.size() * embSize_ * sizeof(float); + if (writeBytesNum != expectWriteBytes) { + string errMsg = StringFormat("Save embedding failed, write expect:%ld, actual:%ld, path:%s.", + expectWriteBytes, writeBytesNum, savePath.c_str()); + throw runtime_error(errMsg); + } } -void EmbeddingDDR::SaveOptimData(const string& savePath) +void EmbeddingDDR::SaveOptimizerSlot(const string& savePath, vector<vector<float>>& optimizerSlots, size_t keySize) { - for (const auto &content: optimContentMap) { + if (optimizerSlots.size() == 0) { + LOG_DEBUG("optimizer has no slot data to save"); + return; + } + + if (optimizerSlots.size() != keySize) { + string errMsg = StringFormat("optimizer slot data size not equal to key size, " + "optimizerSlots.size:%d, keySize:%d", + optimizerSlots.size(), keySize); + throw runtime_error(errMsg); + } + + size_t slotIdx = 0; + for (const auto &slotName: optimParams) { stringstream ss; - ss << savePath << "/" << name << "/" << optimName + "_" + content.first << "/"; + ss << savePath << "/" << name << "/" << optimName + "_" + slotName << "/"; MakeDir(ss.str()); ss << "slice_" << rankId_ << ".data"; - unique_ptr<FileSystemHandler> fileSystemHandler = make_unique<FileSystemHandler>(); - unique_ptr<FileSystem> fileSystemPtr = fileSystemHandler->Create(ss.str()); - vector<char> attribute; - fileSystemPtr->Write(ss.str(), content.second, embSize_ * sizeof(float)); - } -} - -void EmbeddingDDR::SaveEmbAndOptim(const string& savePath) -{ - HostEmb *hostEmbs = Singleton<HostEmb>::GetInstance(); - HostEmbTable &table = hostEmbs->GetEmb(name); - if (table.embData.empty())
{ - LOG_ERROR("host embedding data is empty"); - } - embContent.clear(); - for (const string ¶m: optimParams) { - optimContentMap[param].clear(); - } - for (int64_t &offset: hostOffset) { - embContent.push_back(table.embData[offset - devVocabSize].data()); - int optim_param_count = 1; - for (const string ¶m: optimParams) { - optimContentMap[param].push_back(table.embData[offset - devVocabSize].data() + - sizeof(float) * embSize_ * optim_param_count); - optim_param_count++; + vector> slotData; + for (const auto &data: optimizerSlots) { + vector tmp(data.cbegin() + slotIdx * embSize_, data.cbegin() + (slotIdx+1) * embSize_); + slotData.emplace_back(tmp); } + ssize_t writeBytesNum = fileSystemPtr_->Write(ss.str(), slotData, embSize_); + ssize_t expectWriteBytes = slotData.size() * embSize_ * sizeof(float); + if (writeBytesNum != expectWriteBytes) { + string errMsg = StringFormat("save optimizer slot failed, write expect:%d, actual:%d, path:%s", + expectWriteBytes, writeBytesNum, savePath.c_str()); + throw runtime_error(errMsg); + } + + slotIdx++; } - SaveEmbData(savePath); - SaveOptimData(savePath); } - vector EmbeddingDDR::GetDeviceOffset() { - return deviceOffset; + throw runtime_error("GetDeviceOffset deprecated in ddr/ssd mode"); } void EmbeddingDDR::SetOptimizerInfo(OptimizerInfo& optimizerInfo) { optimName = optimizerInfo.optimName; optimParams = optimizerInfo.optimParams; - for (const string ¶m: optimParams) { - optimContentMap[param] = vector{}; - } } + void EmbeddingDDR::SetCacheManager(CacheManager *cm) { + LOG_DEBUG("set CacheManager"); cacheManager_ = cm; } -void EmbeddingDDR::AddKeyFreqInfo(const emb_key_t& key, RecordType type) -{ - if (!isSSDEnabled_) { - return; - } - cacheManager_->PutKey(name, key, type); -} - -void EmbeddingDDR::RefreshFreqInfoWithSwap() -{ - if (!isSSDEnabled_) { - return; - } - // 换入换出key列表,元素为pair: pair oldKey为从HBM移出的key, key为从DDR移出的key - LOG_DEBUG("RefreshFreqInfoWithSwap, table:{}, oldSwap Size:{}", name, oldSwap.size()); - vector enterDDRKeys; - for (auto keyPair : oldSwap) { - enterDDRKeys.emplace_back(keyPair.first); - } - cacheManager_->RefreshFreqInfoCommon(name, enterDDRKeys, TransferType::HBM_2_DDR); - cacheManager_->RefreshFreqInfoCommon(name, ddr2HbmKeys, TransferType::DDR_2_HBM); - - AddCacheManagerTraceLog(); -} - -/// 记录日志:HBM和DDR换入换出后,比较hostHashMap中DDR内key和表对应的lfuCache对象中的key内容 -void EmbeddingDDR::AddCacheManagerTraceLog() const -{ - if (Logger::GetLevel() != Logger::TRACE) { - return; - } - auto& hostMap = keyOffsetMap; - auto& devSize = devVocabSize; - auto iter = cacheManager_->ddrKeyFreqMap.find(name); - if (iter == cacheManager_->ddrKeyFreqMap.end()) { - throw runtime_error("table not in ddrKeyFreqMap"); - } - auto &lfu = iter->second; - const auto& lfuTab = lfu.GetFreqTable(); - if (lfuTab.empty()) { - return; - } - size_t tableKeyInDdr = 0; - vector ddrKeys; // 获取hostHashMap中保存在DDR的key - for (const auto& item : hostMap) { - if (item.second < devSize) { - continue; - } - ddrKeys.emplace_back(item.first); - ++tableKeyInDdr; - } - vector lfuKeys; - for (const auto& it : lfuTab) { - lfuKeys.emplace_back(it.first); - } - std::sort(ddrKeys.begin(), ddrKeys.end()); - std::sort(lfuKeys.begin(), lfuKeys.end()); - std::string ddrKeysString = VectorToString(ddrKeys); - std::string lfuKeysString = VectorToString(lfuKeys); - if (ddrKeysString != lfuKeysString) { - LOG_ERROR("swap HBM with DDR step error, key string not equal, table:{}, ddrKeysString:{}, lfuKeysString:{}", - name, ddrKeysString, lfuKeysString); - } else { - LOG_INFO("swap HBM 
with DDR step OK, table:{}, ddrKeysString == lfuKeysString, string length:{}", - name, lfuKeysString.length()); - } - - LOG_INFO("swap HBM with DDR step end, table:{}, tableKeyInDdr:{}, tableKeyInLfu:{}", - name, tableKeyInDdr, lfu.keyTable.size()); -} - TableInfo EmbeddingDDR::GetTableInfo() { TableInfo ti = { @@ -607,42 +366,26 @@ TableInfo EmbeddingDDR::GetTableInfo() .devVocabSize=devVocabSize, .maxOffset=maxOffset, .keyOffsetMap=keyOffsetMap, - .evictDevPos=evictDevPos, - .evictHostPos=evictHostPos, }; return ti; } -void EmbeddingDDR::RefreshFreqInfoAfterLoad() +void EmbeddingDDR::SetHDTransfer(HDTransfer *hdTransfer) { - vector h2d; - vector d2h; - - for (const auto& it: cacheManager_->ddrKeyFreqMap[name].keyTable) { - auto key = it.first; - auto iter = keyOffsetMap.find(key); - if (iter == keyOffsetMap.end()) { - throw runtime_error("ddrKeyFreqMap key not in keyOffsetMap"); - } - auto offset = iter->second; - if (offset < devVocabSize) { - d2h.emplace_back(key); - } - } - for (const auto& it: cacheManager_->excludeDDRKeyCountMap[name]) { - auto key = it.first; - auto iter = keyOffsetMap.find(key); - if (iter == keyOffsetMap.end()) { - continue; - } - auto offset = iter->second; - if (offset >= devVocabSize) { - h2d.emplace_back(key); - } - } + this->hdTransfer = hdTransfer; +} + +void EmbeddingDDR::SetEmbCache(ock::ctr::EmbCacheManagerPtr embCache) +{ + this->embCache = embCache; +} - cacheManager_->RefreshFreqInfoCommon(name, h2d, TransferType::HBM_2_DDR); - cacheManager_->RefreshFreqInfoCommon(name, d2h, TransferType::DDR_2_HBM); +void EmbeddingDDR::BackUpTrainStatus() +{ + embCache->BackUpTrainStatus(name); +} - LOG_DEBUG("RefreshFreqInfoAfterLoad done"); +void EmbeddingDDR::RecoverTrainStatus() +{ + embCache->RecoverTrainStatus(name); } diff --git a/src/core/emb_table/embedding_ddr.h b/src/core/emb_table/embedding_ddr.h index b2a461d8edf1a850a28ab28c34e33c42df54ce73..26d85e606b37b414ba62fc41b9782cc4b30fceb9 100644 --- a/src/core/emb_table/embedding_ddr.h +++ b/src/core/emb_table/embedding_ddr.h @@ -34,57 +34,51 @@ public: virtual int64_t capacity() const; - virtual std::vector FindOffset(const vector& keys, - size_t batchId, int channelId, - std::vector& swapPos); + virtual void EvictKeys(const vector& keys); - emb_key_t FindOffsetHelper(const emb_key_t& key, int channelId); + void Load(const string& savePath, map>& trainKeySet); - void UpdateBatchId(const vector& keys, size_t currentBatchId); + void LoadKey(const string& savePath, vector& keys); - emb_key_t FindSwapPosOld(emb_key_t key, size_t hostOffset, size_t batchId, std::vector& swapPos); + void LoadEmbedding(const string& savePath, vector>& embeddings); - virtual void EvictKeys(const vector& keys); + void LoadOptimizerSlot(const string& savePath, vector>& optimizerSlots); -// std::vector lookUpVec; // 查询结果 + void Save(const string& savePath); - virtual void ClearLookupAndSwapOffset(); + void SyncLatestEmbedding(); - void SetStartCount(); + void SaveKey(const string& savePath, vector& keys); - void Load(const string& savePath); + void SaveEmbedding(const string& savePath, vector>& embeddings); - void Save(const string& savePath); + void SaveOptimizerSlot(const string& savePath, vector>& optimizerSlots, size_t keySize); vector GetDeviceOffset(); void SetOptimizerInfo(OptimizerInfo& optimizerInfo); - void RefreshFreqInfoWithSwap(); - - void AddKeyFreqInfo(const emb_key_t& key, RecordType type); - void SetCacheManager(CacheManager *cm); - void AddCacheManagerTraceLog() const; - TableInfo GetTableInfo(); - void 
RefreshFreqInfoAfterLoad(); + void SetHDTransfer(HDTransfer* hdTransfer); -GTEST_PRIVATE: - - int LoadHashMap(const string& savePath); + void LoadKey(const string& savePath); void LoadEmbAndOptim(const string& savePath); - int SaveKey(const string& savePath); + void SaveKey(const string& savePath); void SaveEmbData(const string &savePath); void SaveOptimData(const string& savePath); void SaveEmbAndOptim(const string& savePath); + void SetEmbCache(ock::ctr::EmbCacheManagerPtr embCache); - void EvictDeleteEmb(const vector& keys); + void BackUpTrainStatus(); + void RecoverTrainStatus(); - std::vector devOffset2Key; +GTEST_PRIVATE: + + void EvictDeleteEmb(const vector& keys); size_t maxOffsetOld { 0 }; std::vector evictPosChange; @@ -92,32 +86,16 @@ GTEST_PRIVATE: std::vector> devOffset2KeyOld; std::vector> oldSwap; // (old on dev, old on host) - /* - * HBM与DDR换入换出时,已存在于DDR且要转移到HBM的key(不包含新key); 用于SSD模式 - * (区别于oldSwap: pair.second为已存在于DDR key + 换入换出前映射到DDR的新key) - */ - std::vector ddr2HbmKeys; - std::vector devOffset2Batch; // has -1 - - /** - * 记录HBM上查找空位的当前位置 - * 值域为[0, devVocabSize] - **/ - size_t currentUpdatePos; - size_t currentUpdatePosStart; // 记录HBM上查找空位的起始位置 - - vector hostKey; - vector hostOffset; - vector deviceKey; - vector deviceOffset; - vector embContent; std::string optimName; std::vector optimParams; - std::map> optimContentMap; vector hostLoadOffset; + + HDTransfer *hdTransfer = nullptr; + ock::ctr::EmbCacheManagerPtr embCache = nullptr; + int deviceId = -1; }; } diff --git a/src/core/emb_table/embedding_dynamic.cpp b/src/core/emb_table/embedding_dynamic.cpp index 9fd265464a1f7a1eb547b271557bbba77ae093ae..703d08ad99b9644c13634917c000cccac8aedc7c 100644 --- a/src/core/emb_table/embedding_dynamic.cpp +++ b/src/core/emb_table/embedding_dynamic.cpp @@ -17,7 +17,6 @@ See the License for the specific language governing permissions and #include "utils/logger.h" #include "utils/singleton.h" #include "hd_transfer/hd_transfer.h" -#include "file_system/file_system_handler.h" #include "utils/common.h" using namespace MxRec; @@ -27,7 +26,7 @@ EmbeddingDynamic::EmbeddingDynamic() } EmbeddingDynamic::EmbeddingDynamic(const EmbInfo& info, const RankInfo& rankInfo, int inSeed) - : EmbeddingTable(info, rankInfo, inSeed) + : EmbeddingTable(info, rankInfo, inSeed), deviceId(rankInfo.deviceId) { if (isDynamic_) { auto ret = aclrtSetDevice(static_cast(rankInfo.deviceId)); @@ -78,7 +77,7 @@ void EmbeddingDynamic::Key2Offset(std::vector& keys, int channel) int64_t EmbeddingDynamic::capacity() const { - return capacity_; + return capacity_.load(); } int64_t EmbeddingDynamic::GetEmptyEmbeddingAddress() @@ -104,7 +103,7 @@ void EmbeddingDynamic::MallocEmbeddingBlock(int embNum) float *embAddr = static_cast(block) + (i * extEmbSize_); embeddingList_.push_back(embAddr); } - capacity_ += embNum; + capacity_.fetch_add(embNum); } void EmbeddingDynamic::RandomInit(void* addr, size_t embNum) @@ -128,23 +127,17 @@ void EmbeddingDynamic::RandomInit(void* addr, size_t embNum) void EmbeddingDynamic::Save(const string& savePath) { - int res = SaveKey(savePath); - if (res == -1) { - throw std::runtime_error("save key failed!"); - } + SaveKey(savePath); SaveEmbAndOptim(savePath); } -int EmbeddingDynamic::SaveKey(const string& savePath) +void EmbeddingDynamic::SaveKey(const string& savePath) { stringstream ss; ss << savePath << "/" << name << "/key/"; MakeDir(ss.str()); ss << "slice_" << rankId_ << ".data"; - unique_ptr fileSystemHandler = make_unique(); - unique_ptr fileSystemPtr = 
fileSystemHandler->Create(ss.str()); - deviceKey.clear(); embAddress.clear(); @@ -153,12 +146,19 @@ int EmbeddingDynamic::SaveKey(const string& savePath) embAddress.push_back(it.second); } - ssize_t res = fileSystemPtr->Write(ss.str(), reinterpret_cast<char*>(deviceKey.data()), - static_cast<size_t>(deviceKey.size() * sizeof(int64_t))); + if (fileSystemPtr_ == nullptr) { + throw runtime_error("the file system pointer is null; SetFileSystemPtr must be called first."); + } + size_t writeSize = static_cast<size_t>(deviceKey.size() * sizeof(int64_t)); + ssize_t res = fileSystemPtr_->Write(ss.str(), reinterpret_cast<char*>(deviceKey.data()), writeSize); if (res == -1) { - return -1; + throw runtime_error(StringFormat("Error: Save keys failed. " + "An error occurred while writing file: %s.", ss.str().c_str())); + } + if (res != writeSize) { + throw runtime_error(StringFormat("Error: Save keys failed. Expected to write %d bytes, " + "but actually write %d bytes to file %s.", writeSize, res, ss.str().c_str())); } - return 0; } void EmbeddingDynamic::SaveEmbAndOptim(const string& savePath) @@ -194,31 +194,31 @@ void EmbeddingDynamic::SaveEmbData(const string& savePath) MakeDir(ss.str()); ss << "slice_" << rankId_ << ".data"; - unique_ptr<FileSystemHandler> fileSystemHandler = make_unique<FileSystemHandler>(); - unique_ptr<FileSystem> fileSystemPtr = fileSystemHandler->Create(ss.str()); - fileSystemPtr->WriteEmbedding(ss.str(), embSize_, embAddress, rankId_); + if (fileSystemPtr_ == nullptr) { + throw runtime_error("the file system pointer is null; SetFileSystemPtr must be called first."); + } + fileSystemPtr_->WriteEmbedding(ss.str(), embSize_, embAddress, deviceId); } void EmbeddingDynamic::SaveOptimData(const string &savePath) { + if (fileSystemPtr_ == nullptr) { + throw runtime_error("the file system pointer is null; SetFileSystemPtr must be called first."); + } + for (const auto &content: optimAddressMap) { stringstream ss; ss << savePath << "/" << name << "/" << optimName + "_" + content.first << "/"; MakeDir(ss.str()); ss << "slice_" << rankId_ << ".data"; - unique_ptr<FileSystemHandler> fileSystemHandler = make_unique<FileSystemHandler>(); - unique_ptr<FileSystem> fileSystemPtr = fileSystemHandler->Create(ss.str()); - fileSystemPtr->WriteEmbedding(ss.str(), embSize_, content.second, rankId_); + fileSystemPtr_->WriteEmbedding(ss.str(), embSize_, content.second, deviceId); } } -void EmbeddingDynamic::Load(const string& savePath) +void EmbeddingDynamic::Load(const string& savePath, map>& trainKeySet) { - int res = LoadKey(savePath); - if (res == -1) { - throw std::runtime_error("load key failed!"); - } + LoadKey(savePath); LoadEmbAndOptim(savePath); } @@ -227,52 +227,56 @@ void EmbeddingDynamic::LoadEmbAndOptim(const string& savePath) stringstream ss; ss << savePath << "/" << name; - unique_ptr<FileSystemHandler> fileSystemHandler = make_unique<FileSystemHandler>(); - unique_ptr<FileSystem> fileSystemPtr = fileSystemHandler->Create(ss.str()); - // read the embedding data stringstream embedStream; embedStream << ss.str() << "/" << "embedding/slice.data"; + + if (fileSystemPtr_ == nullptr) { + throw runtime_error("the file system pointer is null; SetFileSystemPtr must be called first."); + } EmbeddingSizeInfo embeddingSizeInfo = {embSize_, extEmbSize_}; - fileSystemPtr->ReadEmbedding(savePath, embeddingSizeInfo, firstAddress, rankId_, loadOffset); + fileSystemPtr_->ReadEmbedding(embedStream.str(), embeddingSizeInfo, firstAddress, rankId_, loadOffset); // read the optimizer slots int optimIndex = 1; for (const auto &param: optimParams) { stringstream paramStream; paramStream << ss.str() << "/" << optimName + "_" + param << "/slice.data"; - fileSystemPtr->ReadEmbedding(paramStream.str(),
embeddingSizeInfo, - firstAddress + optimIndex * embSize_ * sizeof(float), rankId_, loadOffset); + fileSystemPtr_->ReadEmbedding(paramStream.str(), embeddingSizeInfo, + firstAddress + optimIndex * embSize_ * sizeof(float), deviceId, loadOffset); optimIndex++; } } -int EmbeddingDynamic::LoadKey(const string& savePath) +void EmbeddingDynamic::LoadKey(const string& savePath) { stringstream ss; ss << savePath << "/" << name << "/key/slice.data"; - unique_ptr<FileSystemHandler> fileSystemHandler = make_unique<FileSystemHandler>(); - unique_ptr<FileSystem> fileSystemPtr = fileSystemHandler->Create(ss.str()); - - size_t fileSize = 0; - try { - fileSize = fileSystemPtr->GetFileSize(ss.str()); - } catch (exception& e) { - LOG_ERROR("open file {} failed:{}", ss.str(), strerror(errno)); - return -1; + if (fileSystemPtr_ == nullptr) { + throw runtime_error("the file system pointer is null; SetFileSystemPtr must be called first."); } + size_t fileSize = fileSystemPtr_->GetFileSize(ss.str()); if (fileSize >= FILE_MAX_SIZE) { - LOG_ERROR("file {} size = {} is too big", ss.str(), fileSize); - return -1; + throw runtime_error(StringFormat("Error: Load keys failed. " + "file %s size %d is too big.", ss.str().c_str(), fileSize)); } int64_t* buf = static_cast<int64_t*>(malloc(fileSize)); if (buf == nullptr) { - LOG_ERROR("malloc failed: {}", strerror(errno)); - return -1; + throw runtime_error(StringFormat("Error: Load keys failed. " + "failed to allocate %d bytes using malloc.", fileSize)); + } + + ssize_t res = fileSystemPtr_->Read(ss.str(), reinterpret_cast<char*>(buf), fileSize); + if (res == -1) { + throw runtime_error(StringFormat("Error: Load keys failed. " + "An error occurred while reading file: %s.", ss.str().c_str())); + } + if (res != fileSize) { + throw runtime_error(StringFormat("Error: Load keys failed. Expected to read %d bytes, " + "but actually read %d bytes from file %s.", fileSize, res, ss.str().c_str())); } - fileSystemPtr->Read(ss.str(), reinterpret_cast<char*>(buf), fileSize); size_t loadKeySize = fileSize / sizeof(int64_t); @@ -289,7 +293,8 @@ int EmbeddingDynamic::LoadKey(const string& savePath) void *newBlock = nullptr; aclError ret = aclrtMalloc(&newBlock, static_cast<size_t>(datasetSize), ACL_MEM_MALLOC_HUGE_FIRST); if (ret != ACL_SUCCESS) { - throw runtime_error(StringFormat("aclrtMalloc failed, ret=%d", ret).c_str()); + throw runtime_error(StringFormat("Error: in dynamic expansion mode, " + "aclrtMalloc failed, malloc size: %d.", datasetSize)); } // newBlock is the first address here; // rebuild the key_offset map from the loaded keys @@ -303,5 +308,4 @@ int EmbeddingDynamic::LoadKey(const string& savePath) maxOffset = keyOffsetMap.size(); free(static_cast<void*>(buf)); - return 0; } diff --git a/src/core/emb_table/embedding_dynamic.h b/src/core/emb_table/embedding_dynamic.h index 2c867530259d2d2db0ab88b79a2c101172eec142..5cf497180c14a7304170c265b087d539d10d4f98 100644 --- a/src/core/emb_table/embedding_dynamic.h +++ b/src/core/emb_table/embedding_dynamic.h @@ -35,7 +35,7 @@ public: virtual int64_t capacity() const; - void Load(const string& savePath); + void Load(const string& savePath, map>& trainKeySet); void Save(const string& savePath); @@ -48,13 +48,13 @@ private: void MallocEmbeddingBlock(int embNum); - int SaveKey(const string& savePath); + void SaveKey(const string& savePath); void SaveEmbAndOptim(const string& savePath); void SetOptimizerInfo(OptimizerInfo& optimizerInfo); - int LoadKey(const string& savePath); + void LoadKey(const string& savePath); void LoadEmbAndOptim(const string& savePath); @@ -74,6 +74,7 @@ private: std::string optimName; std::vector optimParams; std::map>
optimAddressMap; + int deviceId = -1; int64_t firstAddress; }; diff --git a/src/core/emb_table/embedding_mgmt.cpp b/src/core/emb_table/embedding_mgmt.cpp index 2c2f9e398c2c9c6f8a50b6342d3ebc0d00ddf61c..d889cdba58ea51f95c743448ea29f19c77c56cc2 100644 --- a/src/core/emb_table/embedding_mgmt.cpp +++ b/src/core/emb_table/embedding_mgmt.cpp @@ -12,11 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #include "emb_table/embedding_mgmt.h" + +#include + #include "emb_table/embedding_static.h" #include "emb_table/embedding_dynamic.h" #include "emb_table/embedding_ddr.h" +#include "file_system/file_system_handler.h" #include "utils/logger.h" using namespace MxRec; @@ -25,8 +28,7 @@ EmbeddingMgmt::EmbeddingMgmt() { } -void EmbeddingMgmt::Init(const RankInfo& rInfo, const vector& eInfos, - const vector& thresholdValues, int seed) +void EmbeddingMgmt::Init(const RankInfo& rInfo, const vector& eInfos, int seed) { for (size_t i = 0; i < eInfos.size(); ++i) { if (rInfo.isDDR) { @@ -54,17 +56,7 @@ void EmbeddingMgmt::Key2Offset(const std::string& name, std::vector& size_t EmbeddingMgmt::GetMaxOffset(const std::string& name) { - embeddings[name]->GetMaxOffset(); -} - -void EmbeddingMgmt::LoadMaxOffset(OffsetMemT& loadData) -{ - LOG_ERROR("load max offset"); -} - -void EmbeddingMgmt::LoadKeyOffsetMap(KeyOffsetMemT& loadData) -{ - LOG_ERROR("load key offset"); + return embeddings[name]->GetMaxOffset(); } std::map EmbeddingMgmt::GetMaxOffset() @@ -85,7 +77,7 @@ KeyOffsetMemT EmbeddingMgmt::GetKeyOffsetMap() return keyOffsetMap; } -void EmbeddingMgmt::EvictKeys(const string& name, const vector& keys) +void EmbeddingMgmt::EvictKeys(const string& name, const vector& keys) { LOG_ERROR("evict keys for {}", name); if (keys.size() != 0) { @@ -94,7 +86,7 @@ void EmbeddingMgmt::EvictKeys(const string& name, const vector& keys) embeddings[name]->EvictInitDeviceEmb(); } -void EmbeddingMgmt::EvictKeysCombine(const vector& keys) +void EmbeddingMgmt::EvictKeysCombine(const vector& keys) { if (keys.size() != 0) { for (auto& table: embeddings) { @@ -117,47 +109,47 @@ int64_t EmbeddingMgmt::GetCapacity(const std::string &name) return embeddings[name]->capacity(); } -void EmbeddingMgmt::FindOffset(const std::string& name, const vector& keys, - size_t currentBatchId, size_t keepBatchId, int channel) +void EmbeddingMgmt::Load(const string& name, const string& filePath, + map>& trainKeySet) { - return embeddings[name]->FindOffset(keys, currentBatchId, keepBatchId, channel); + embeddings[name]->SetFileSystemPtr(filePath); + embeddings[name]->Load(filePath, trainKeySet); + embeddings[name]->UnsetFileSystemPtr(); } -const std::vector& EmbeddingMgmt::GetMissingKeys(const std::string& name) -{ - return embeddings[name]->GetMissingKeys(); -} - -void EmbeddingMgmt::ClearMissingKeys(const std::string& name) -{ - return embeddings[name]->ClearMissingKeys(); -} - -std::shared_ptr EmbeddingMgmt::GetTable(const string& name) -{ - auto it = embeddings.find(name); - if (it == embeddings.end()) { - LOG_ERROR("table not found"); - } - return std::dynamic_pointer_cast(it->second); -} - -void EmbeddingMgmt::Load(const string& filePath) +void EmbeddingMgmt::Load(const string& filePath, map>& trainKeySet) { for (auto& tablePair: embeddings) { - tablePair.second->Load(filePath); + tablePair.second->SetFileSystemPtr(filePath); + 
tablePair.second->Load(filePath, trainKeySet); + tablePair.second->UnsetFileSystemPtr(); } } void EmbeddingMgmt::Save(const string& name, const string& filePath) { - return embeddings[name]->Save(filePath); + embeddings[name]->SetFileSystemPtr(filePath); + embeddings[name]->Save(filePath); + embeddings[name]->UnsetFileSystemPtr(); } void EmbeddingMgmt::Save(const string& filePath) { for (auto& tablePair: embeddings) { - tablePair.second->Save(filePath); + tablePair.second->SetFileSystemPtr(filePath); + } + // save the tables in parallel so that receiving save_d2h is not blocked when the C++ and Python sides process the tables in different orders + vector<std::future<void>> futures; + for (auto& tablePair: embeddings) { + futures.emplace_back( + std::async(std::launch::async, [table = tablePair.second, filePath] { table->Save(filePath); })); + } + for (auto& f: futures) { + f.get(); // get() rethrows any exception thrown inside the async task + } + + for (auto& tablePair: embeddings) { + tablePair.second->UnsetFileSystemPtr(); + } }
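Reviewer note: the parallel save above leans on the fact that std::future::get() rethrows an exception raised inside the async task, so a failure in any one table's Save still surfaces in the calling thread after all tasks have been launched. A minimal standalone sketch of that pattern (illustrative values, not mxRec code):

```cpp
#include <future>
#include <stdexcept>
#include <vector>

int main() {
    // One async task per item; each future stores the task's exception, if any.
    std::vector<std::future<void>> futures;
    for (int i = 0; i < 3; ++i) {
        futures.emplace_back(std::async(std::launch::async, [i] {
            if (i == 1) {
                throw std::runtime_error("save failed");  // stored in the future
            }
        }));
    }
    try {
        for (auto& f : futures) {
            f.get();  // rethrows here, in the launching thread
        }
    } catch (const std::exception&) {
        // handle or propagate the first observed failure
    }
    return 0;
}
```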
@@ -175,18 +167,6 @@ void EmbeddingMgmt::SetOptimizerInfo(const string& name, OptimizerInfo& optimizerInfo) embeddings[name]->SetOptimizerInfo(optimizerInfo); } -EmbHashMemT EmbeddingMgmt::GetEmbHashMaps() -{ - EmbHashMemT EmbHashMaps; - for (auto& tablePair: embeddings) { - EmbHashMaps[tablePair.first].hostHashMap = tablePair.second ->GetKeyOffsetMap(); - EmbHashMaps[tablePair.first].devVocabSize = tablePair.second ->GetDevVocabSize(); - EmbHashMaps[tablePair.first].hostVocabSize = tablePair.second ->GetHostVocabSize(); - EmbHashMaps[tablePair.first].maxOffset = tablePair.second ->GetMaxOffset(); - } - return EmbHashMaps; -} - OffsetMapT EmbeddingMgmt::GetLoadOffsets() { OffsetMapT AllLoadOffsets; @@ -203,25 +183,30 @@ void EmbeddingMgmt::SetCacheManagerForEmbTable(CacheManager* cacheManager) } } -void EmbeddingMgmt::EnableSSD() +void EmbeddingMgmt::SetHDTransferForEmbTable(HDTransfer* hdTransfer) +{ + for (auto& table: embeddings) { + table.second->SetHDTransfer(hdTransfer); + } +} + +void EmbeddingMgmt::SetEmbCacheForEmbTable(const ock::ctr::EmbCacheManagerPtr& embCache) { for (auto& table: embeddings) { - table.second->EnableSSD(); + table.second->SetEmbCache(embCache); } } -void EmbeddingMgmt::LockSave() +void EmbeddingMgmt::BackUpTrainStatusBeforeLoad() { for (auto& table: embeddings) { - table.second->mutSave_.lock(); + table.second->BackUpTrainStatus(); } - LOG_DEBUG("LockSave"); } -void EmbeddingMgmt::UnLockSave() +void EmbeddingMgmt::RecoverTrainStatus() { for (auto& table: embeddings) { - table.second->mutSave_.unlock(); + table.second->RecoverTrainStatus(); } - LOG_DEBUG("UnLockSave"); } \ No newline at end of file diff --git a/src/core/emb_table/embedding_mgmt.h b/src/core/emb_table/embedding_mgmt.h index 11ed23254f753493edf8add9e20bf6487f43dd92..9dd0e363292f1dae84d618333e4d216c891e4494 100644 --- a/src/core/emb_table/embedding_mgmt.h +++ b/src/core/emb_table/embedding_mgmt.h @@ -34,8 +34,7 @@ public: * @param[in] rInfo rank info passed from the Python side * @param[in] eInfos embedding table info passed from the Python side */ - void Init(const RankInfo& rInfo, const vector& eInfos, - const vector& thresholdValues = {}, int seed = 0); + void Init(const RankInfo& rInfo, const vector& eInfos, int seed = 0); /** * Batch-look-up keys in an embedding table * @param[in] name embedding table name * @param[in] keys keys to look up * @param[in] channel channel id */ void Key2Offset(const std::string& name, std::vector& keys, int channel); - void FindOffset(const std::string& name, const vector& keys, - size_t currentBatchId, size_t keepBatchId, int channel); - /** * Evict keys from the specified embedding table * @param[in] name embedding table name * @param[in] keys keys to evict */ - void EvictKeys(const std::string& name, const vector& keys); + void EvictKeys(const std::string& name, const vector& keys); /** * Evict keys from all embedding tables * @param[in] keys keys to evict */ - void EvictKeysCombine(const vector& keys); - - const std::vector& GetMissingKeys(const std::string& name); - - void ClearMissingKeys(const std::string& name); - - void LoadMaxOffset(OffsetMemT& loadData); - - void LoadKeyOffsetMap(KeyOffsetMemT& loadData); + void EvictKeysCombine(const vector& keys); size_t GetMaxOffset(const std::string& name); @@ -81,12 +69,15 @@ public: static EmbeddingMgmt* Instance(); - std::shared_ptr GetTable(const string& name); + /** + * Load a single table + */ + void Load(const string& name, const string& filePath, map>& trainKeySet); /** * Load all tables */ - void Load(const string& filePath); + void Load(const string& filePath, map>& trainKeySet); /** * Save a single table */ @@ -98,6 +89,16 @@ */ void Save(const string& filePath); + /** + * In estimator mode, when switching from train to eval, back up the training state of all tables. + */ + void BackUpTrainStatusBeforeLoad(); + + /** + * In estimator mode, when switching from eval to train, recover the training state of all tables. + */ + void RecoverTrainStatus(); + /** * Get the DeviceOffsets of all tables; the Python side uses these offsets to extract each key's embedding when saving */ @@ -108,8 +109,6 @@ */ OffsetMapT GetLoadOffsets(); - EmbHashMemT GetEmbHashMaps(); - /** * Set the optimizer info of a given table */ @@ -117,11 +116,9 @@ void SetCacheManagerForEmbTable(CacheManager* cacheManager); - void EnableSSD(); - - void LockSave(); + void SetHDTransferForEmbTable(HDTransfer* hdTransfer); - void UnLockSave(); + void SetEmbCacheForEmbTable(const ock::ctr::EmbCacheManagerPtr& embCache); private: EmbeddingMgmt(); diff --git a/src/core/emb_table/embedding_static.cpp b/src/core/emb_table/embedding_static.cpp index 225c90c91d8a9a7508104821ce3a0ae84ca3b6f9..0db152ed8ce0ece19dc1f073ff915abac3fedc1e 100644 --- a/src/core/emb_table/embedding_static.cpp +++ b/src/core/emb_table/embedding_static.cpp @@ -73,22 +73,16 @@ int64_t EmbeddingStatic::capacity() const void EmbeddingStatic::Save(const string& savePath) { - int res = SaveKey(savePath); - if (res == -1) { - throw std::runtime_error("save embedding table failed!"); - } + SaveKey(savePath); } -int EmbeddingStatic::SaveKey(const string& savePath) +void EmbeddingStatic::SaveKey(const string& savePath) { stringstream ss; ss << savePath << "/" << name << "/key/"; MakeDir(ss.str()); ss << "slice_" << rankId_ << ".data"; - unique_ptr<FileSystemHandler> fileSystemHandler = make_unique<FileSystemHandler>(); - unique_ptr<FileSystem> fileSystemPtr = fileSystemHandler->Create(ss.str()); - deviceKey.clear(); deviceOffset.clear(); @@ -97,48 +91,56 @@ int EmbeddingStatic::SaveKey(const string& savePath) deviceOffset.push_back(it.second); } - ssize_t res = fileSystemPtr->Write(ss.str(), reinterpret_cast<char*>(deviceKey.data()), - static_cast<size_t>(deviceKey.size() * sizeof(int64_t))); + if (fileSystemPtr_ == nullptr) { + throw runtime_error("the file system pointer is null; SetFileSystemPtr must be called first."); + } + + size_t writeSize = static_cast<size_t>(deviceKey.size() * sizeof(int64_t)); + ssize_t res = fileSystemPtr_->Write(ss.str(), reinterpret_cast<char*>(deviceKey.data()), writeSize); if (res == -1) { - return -1; + throw runtime_error(StringFormat("Error: Save keys failed. " + "An error occurred while writing file: %s.", ss.str().c_str())); + } + if (res != writeSize) { + throw runtime_error(StringFormat("Error: Save keys failed. "
Expected to write %d bytes, " + "but actually write %d bytes to file %s.", writeSize, res, ss.str().c_str())); } - return 0; } -void EmbeddingStatic::Load(const string& savePath) +void EmbeddingStatic::Load(const string& savePath, map>& trainKeySet) { - int res = LoadKey(savePath); - if (res == -1) { - throw std::runtime_error("load embedding table failed!"); - } + LoadKey(savePath); } -int EmbeddingStatic::LoadKey(const string &savePath) +void EmbeddingStatic::LoadKey(const string& savePath) { stringstream ss; ss << savePath << "/" << name << "/key/slice.data"; - unique_ptr<FileSystemHandler> fileSystemHandler = make_unique<FileSystemHandler>(); - unique_ptr<FileSystem> fileSystemPtr = fileSystemHandler->Create(ss.str()); - - size_t fileSize = 0; - try { - fileSize = fileSystemPtr->GetFileSize(ss.str()); - } catch (exception &e) { - LOG_ERROR("open file {} failed:{}", ss.str(), strerror(errno)); - return -1; + if (fileSystemPtr_ == nullptr) { + throw runtime_error("the file system pointer is null; SetFileSystemPtr must be called first."); } + size_t fileSize = fileSystemPtr_->GetFileSize(ss.str()); if (fileSize >= FILE_MAX_SIZE) { - LOG_ERROR("file {} size = {} is too big", ss.str(), fileSize); - return -1; + throw runtime_error(StringFormat("Error: Load keys failed. " + "file %s size %d is too big.", ss.str().c_str(), fileSize)); } - int64_t* buf = static_cast<int64_t*>(malloc(fileSize)); + int64_t* buf = static_cast<int64_t*>(malloc(fileSize)); if (buf == nullptr) { - LOG_ERROR("malloc failed: {}", strerror(errno)); - return -1; + throw runtime_error(StringFormat("Error: Load keys failed. " + "failed to allocate %d bytes using malloc.", fileSize)); + } + + ssize_t res = fileSystemPtr_->Read(ss.str(), reinterpret_cast<char*>(buf), fileSize); + if (res == -1) { + throw runtime_error(StringFormat("Error: Load keys failed. " + "An error occurred while reading file: %s.", ss.str().c_str())); + } + if (res != fileSize) { + throw runtime_error(StringFormat("Error: Load keys failed. Expected to read %d bytes, " + "but actually read %d bytes from file %s.", fileSize, res, ss.str().c_str())); } - fileSystemPtr->Read(ss.str(), reinterpret_cast<char*>(buf), fileSize); size_t loadKeySize = fileSize / sizeof(int64_t); loadOffset.clear(); @@ -152,17 +154,29 @@ int EmbeddingStatic::LoadKey(const string &savePath) } if (loadOffset.size() > devVocabSize) { - LOG_ERROR("load key size exceeds device vocab size: {}", strerror(errno)); - return -1; + free(static_cast<void*>(buf)); + throw runtime_error(StringFormat("Error: Load keys failed. 
Load key size: %d exceeds device vocab size: %d.", + loadOffset.size(), devVocabSize)); } maxOffset = keyOffsetMap.size(); free(static_cast<void*>(buf)); - return 0; } vector EmbeddingStatic::GetDeviceOffset() { return deviceOffset; -} \ No newline at end of file +} + +void EmbeddingStatic::BackUpTrainStatus() +{ + keyOffsetMapBackUp = keyOffsetMap; +} + +void EmbeddingStatic::RecoverTrainStatus() +{ + if (keyOffsetMapBackUp.size() != 0) { + keyOffsetMap = keyOffsetMapBackUp; + keyOffsetMapBackUp.clear(); + } +} diff --git a/src/core/emb_table/embedding_static.h b/src/core/emb_table/embedding_static.h index 06e24efafd27b1717156c9bd0f73f351d8de2089..6f772e0891a09dd24e05983a125d2e046f01095e 100644 --- a/src/core/emb_table/embedding_static.h +++ b/src/core/emb_table/embedding_static.h @@ -35,16 +35,20 @@ public: virtual int64_t capacity() const; - void Load(const string& savePath); + void Load(const string& savePath, map>& trainKeySet); void Save(const string& savePath); + void BackUpTrainStatus(); + + void RecoverTrainStatus(); + vector GetDeviceOffset(); GTEST_PRIVATE: - int SaveKey(const string& savePath); + void SaveKey(const string& savePath); - int LoadKey(const string& savePath); + void LoadKey(const string& savePath); vector deviceKey; vector deviceOffset; diff --git a/src/core/emb_table/embedding_table.cpp b/src/core/emb_table/embedding_table.cpp index 7cfc125e8bd13919058f7d85f6f4b8ca0c78b41b..12b0137a984b7e76925b783ac27b14e10a166b43 100644 --- a/src/core/emb_table/embedding_table.cpp +++ b/src/core/emb_table/embedding_table.cpp @@ -27,7 +27,7 @@ EmbeddingTable::EmbeddingTable() EmbeddingTable::EmbeddingTable(const EmbInfo& info, const RankInfo& rankInfo, int inSeed) : name(info.name), hostVocabSize(info.hostVocabSize), devVocabSize(info.devVocabSize), - freeSize_(0), maxOffset(0), isDynamic_(rankInfo.useDynamicExpansion), + ssdVocabSize(info.ssdVocabSize), freeSize_(0), maxOffset(0), isDynamic_(rankInfo.useDynamicExpansion), embSize_(info.embeddingSize), extEmbSize_(info.extEmbeddingSize), embInfo_(info), seed_(inSeed), rankId_(rankInfo.rankId), rankSize_(rankInfo.rankSize) { @@ -43,19 +43,6 @@ void EmbeddingTable::Key2Offset(std::vector& keys, int channel) return; } -void EmbeddingTable::FindOffset(const vector& keys, - size_t currentBatchId, size_t keepBatchId, int channelId) -{ - return; -} - -std::vector EmbeddingTable::FindOffset(const vector& keys, - size_t batchId, int channelId, - std::vector& swapPos) -{ - return {}; -} - size_t EmbeddingTable::GetMaxOffset() { return maxOffset; @@ -71,7 +58,7 @@ size_t EmbeddingTable::size() const return maxOffset; } -void EmbeddingTable::EvictKeys(const std::vector& keys) +void EmbeddingTable::EvictKeys(const std::vector& keys) { std::lock_guard<std::mutex> lk(mut_); // lock for PROCESS_THREAD size_t keySize = keys.size(); @@ -132,32 +119,15 @@ absl::flat_hash_map EmbeddingTable::GetKeyOffsetMap() return keyOffsetMap; } -void EmbeddingTable::ClearMissingKeys() -{ - missingKeysHostPos_.clear(); -} - -const std::vector& EmbeddingTable::GetMissingKeys() -{ - return missingKeysHostPos_; -} - -void EmbeddingTable::SetStartCount() -{ -} - -void EmbeddingTable::ClearLookupAndSwapOffset() +void EmbeddingTable::SetFileSystemPtr(const string& savePath) { + unique_ptr<FileSystemHandler> fileSystemHandler = make_unique<FileSystemHandler>(); + fileSystemPtr_ = fileSystemHandler->Create(savePath); } -size_t EmbeddingTable::GetDevVocabSize() +void EmbeddingTable::UnsetFileSystemPtr() { - return devVocabSize; -} - -size_t EmbeddingTable::GetHostVocabSize() -{ - return hostVocabSize; + fileSystemPtr_ = nullptr; }
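+ // Lifecycle note: fileSystemPtr_ is only valid between SetFileSystemPtr and
+ // UnsetFileSystemPtr; EmbeddingMgmt brackets every Save/Load with this pair,
+ // and all file-touching helpers null-check the pointer before using it.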
vector EmbeddingTable::GetLoadOffset() @@ -165,7 +135,7 @@ vector EmbeddingTable::GetLoadOffset() return loadOffset; } -void EmbeddingTable::Load(const string& filePath) +void EmbeddingTable::Load(const string& filePath, map>& trainKeySet) { } @@ -173,23 +143,23 @@ void EmbeddingTable::Save(const string& filePath) { } -void EmbeddingTable::MakeDir(const string& dirName) +void EmbeddingTable::BackUpTrainStatus() { - auto fileSystemHandler = make_unique<FileSystemHandler>(); - unique_ptr<FileSystem> fileSystemPtr = fileSystemHandler->Create(dirName); - fileSystemPtr->CreateDir(dirName); } -void EmbeddingTable::SetCacheManager(CacheManager *cm) +void EmbeddingTable::RecoverTrainStatus() { } -void EmbeddingTable::EnableSSD() +void EmbeddingTable::MakeDir(const string& dirName) { - isSSDEnabled_ = true; + if (fileSystemPtr_ == nullptr) { + throw runtime_error("the file system pointer is null; SetFileSystemPtr must be called first."); + } + fileSystemPtr_->CreateDir(dirName); } -void EmbeddingTable::RefreshFreqInfoWithSwap() +void EmbeddingTable::SetCacheManager(CacheManager *cm) { } @@ -201,8 +171,6 @@ TableInfo EmbeddingTable::GetTableInfo() .devVocabSize=devVocabSize, .maxOffset=maxOffset, .keyOffsetMap=keyOffsetMap, - .evictDevPos=evictDevPos, - .evictHostPos=evictHostPos, }; return ti; } @@ -214,4 +182,12 @@ vector EmbeddingTable::GetDeviceOffset() void EmbeddingTable::SetOptimizerInfo(OptimizerInfo& optimizerInfo) { -} \ No newline at end of file +} + +void EmbeddingTable::SetHDTransfer(HDTransfer *hdTransfer) +{ +} + +void EmbeddingTable::SetEmbCache(ock::ctr::EmbCacheManagerPtr embCache) +{ +} diff --git a/src/core/emb_table/embedding_table.h b/src/core/emb_table/embedding_table.h index 0c05a0a0cd3e8990154bc100a1725a7cd5f8c78f..da6a42bee1e9f6eb719ed139e3ff4d7af0d773b1 100644 --- a/src/core/emb_table/embedding_table.h +++ b/src/core/emb_table/embedding_table.h @@ -15,12 +15,14 @@ See the License for the specific language governing permissions and #ifndef MX_REC_EMBEDDING_TABLE_H #define MX_REC_EMBEDDING_TABLE_H +#include <atomic> #include #include #include #include "utils/common.h" -#include "ssd_cache/cache_manager.h" +#include "l3_storage/cache_manager.h" +#include "file_system/file_system_handler.h" namespace MxRec { @@ -37,21 +39,11 @@ public: */ virtual void Key2Offset(std::vector& keys, int channel); - /** - * used in DDR mode - */ - virtual void FindOffset(const vector& keys, - size_t currentBatchId, size_t keepBatchId, int channelId); - - virtual std::vector FindOffset(const vector& keys, - size_t batchId, int channelId, - std::vector& swapPos); - /** * Evict keys; used together with GetEvictedKeys: * EvictKeys performs the eviction, and GetEvictedKeys returns the evicted results */ - virtual void EvictKeys(const std::vector& keys); + virtual void EvictKeys(const std::vector& keys); /** * Get the offsets or addresses of the keys evicted on the device side @@ -73,25 +65,21 @@ public: virtual size_t size() const; - void ClearMissingKeys(); - - virtual const std::vector& GetMissingKeys(); - absl::flat_hash_map GetKeyOffsetMap(); - virtual void SetStartCount(); + void SetFileSystemPtr(const string& savePath); - virtual void ClearLookupAndSwapOffset(); + void UnsetFileSystemPtr(); - virtual void Load(const string& savePath); + virtual void Load(const string& savePath, map>& trainKeySet); virtual void Save(const string& savePath); - size_t GetDevVocabSize(); + void MakeDir(const string& dirName); - size_t GetHostVocabSize(); + virtual void BackUpTrainStatus(); - static void MakeDir(const string& dirName); + virtual void RecoverTrainStatus(); virtual vector GetDeviceOffset(); @@ -101,20 +89,21 @@ public: virtual
void SetCacheManager(CacheManager* cacheManager); - void EnableSSD(); + virtual TableInfo GetTableInfo(); - virtual void RefreshFreqInfoWithSwap(); + virtual void SetHDTransfer(HDTransfer *hdTransfer); - virtual TableInfo GetTableInfo(); + virtual void SetEmbCache(ock::ctr::EmbCacheManagerPtr embCache); std::string name; size_t hostVocabSize; size_t devVocabSize; + size_t ssdVocabSize; size_t maxOffset; absl::flat_hash_map keyOffsetMap; + absl::flat_hash_map keyOffsetMapBackUp; std::vector evictDevPos; // keys evicted from HBM std::vector evictHostPos; // eviction list on the host - std::mutex mutSave_; // locks KeyOffsetMap while saving #ifdef NDEBUG protected: @@ -130,7 +119,7 @@ protected: size_t embSize_; size_t extEmbSize_; int seed_; - int64_t capacity_; + std::atomic<int64_t> capacity_{0}; size_t rankId_; size_t rankSize_; vector loadOffset; @@ -138,6 +127,8 @@ protected: std::vector missingKeysHostPos_; // offsets on the host to swap out for the current batch CacheManager* cacheManager_; bool isSSDEnabled_ = false; + + unique_ptr<FileSystem> fileSystemPtr_; }; } diff --git a/src/core/file_system/file_system.h b/src/core/file_system/file_system.h index 2f7d3b62e5f41b015abc42632386eae25e58f312..5546c691b0605c44cfa6c38d2bb28bac9df2a206 100644 --- a/src/core/file_system/file_system.h +++ b/src/core/file_system/file_system.h @@ -31,13 +31,16 @@ namespace MxRec { virtual size_t GetFileSize(const string& filePath) = 0; virtual ssize_t Write(const string& filePath, const char* fileContent, size_t dataSize) = 0; - virtual ssize_t Write(const string& filePath, vector fileContent, size_t dataSize) = 0; + virtual ssize_t Write(const string& filePath, vector<vector<float>>& fileContent, size_t dataSize) = 0; virtual void WriteEmbedding(const string& filePath, const int& embeddingSize, const vector& addressArr, int deviceId) = 0; virtual ssize_t Read(const string& filePath, char* fileContent, size_t datasetSize) = 0; virtual ssize_t Read(const string& filePath, vector>& fileContent, int64_t contentOffset, vector offsetArr, const size_t& embeddingSize) = 0; + + // In the dynamic expansion mode, embeddings are read from the file + // and copied from the host side to the device side.
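+ // offsetArr holds per-row indices into the file; each row is staged in a
+ // host buffer and then copied to device memory starting at firstAddress
+ // (see the aclrtMemcpy calls in the implementations).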
virtual void ReadEmbedding(const string& filePath, EmbeddingSizeInfo& embedSizeInfo, int64_t firstAddress, int deviceId, vector offsetArr) = 0; diff --git a/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp b/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp index 999f2fa9385d1b33a13bbc87ebcc147c4f33d1c4..45c50f6f43c703f25b426e9fa699d71c03e85c04 100644 --- a/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp +++ b/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp @@ -28,19 +28,15 @@ using namespace MxRec; void HdfsFileSystem::CreateDir(const string& dirName) { - hdfsFS fs = ConnectHdfs(); int ret = hdfs->CreateDirectory(fs, dirName.c_str()); if (ret == -1) { LOG_DEBUG("Unable to create hdfs directory: {}", dirName); } - hdfs->Disconnect(fs); } vector HdfsFileSystem::ListDir(const string& dirName) { vector dirs; - hdfsFS fs = ConnectHdfs(); - int numEntries = 0; hdfsFileInfo* subDirs = hdfs->ListDirectory(fs, dirName.c_str(), &numEntries); for (int i = 0; i < numEntries; ++i) { @@ -50,17 +46,14 @@ vector HdfsFileSystem::ListDir(const string& dirName) } hdfs->FreeFileInfo(subDirs, numEntries); - hdfs->Disconnect(fs); return dirs; } size_t HdfsFileSystem::GetFileSize(const string& filePath) { - hdfsFS fs = ConnectHdfs(); hdfsFileInfo* fileInfo = hdfs->GetPathInfo(fs, filePath.c_str()); - hdfs->Disconnect(fs); if (fileInfo == nullptr) { - return 0; + throw runtime_error(StringFormat("Error: Unable to get hdfs file info : %s.", filePath.c_str())); } auto fileSize = static_cast(fileInfo->mSize); return fileSize; @@ -68,77 +61,41 @@ size_t HdfsFileSystem::GetFileSize(const string& filePath) ssize_t HdfsFileSystem::Write(const string& filePath, const char* fileContent, size_t dataSize) { - hdfsFS fs = ConnectHdfs(); - hdfsFile file = hdfs->OpenFile(fs, filePath.c_str(), O_WRONLY | O_CREAT, 0, 0, 0); if (!file) { - hdfs->Disconnect(fs); - throw runtime_error("Error writing to hdfs file."); + throw runtime_error(StringFormat("Error: Unable to open hdfs file : %s.", filePath.c_str())); } - size_t dataCol = dataSize; - size_t writeSize = 0; - size_t idx = 0; tSize writeBytesNum = 0; - - while (dataCol != 0) { - if (dataCol > oneTimeReadWriteLen) { - writeSize = oneTimeReadWriteLen; - } else { - writeSize = dataCol; - } - - tSize res = hdfs->Write(fs, file, fileContent + idx, writeSize); - if (res == -1) { - hdfs->CloseFile(fs, file); - hdfs->Disconnect(fs); - return static_cast(res); - } - dataCol -= writeSize; - idx += writeSize; - writeBytesNum += res; + tSize res = hdfs->Write(fs, file, fileContent, dataSize); + if (res == -1) { + hdfs->CloseFile(fs, file); + return static_cast(res); } + writeBytesNum += res; hdfs->CloseFile(fs, file); - hdfs->Disconnect(fs); return static_cast(writeBytesNum); } -ssize_t HdfsFileSystem::Write(const string& filePath, vector fileContent, size_t dataSize) +ssize_t HdfsFileSystem::Write(const string& filePath, vector>& fileContent, size_t dataSize) { - hdfsFS fs = ConnectHdfs(); - hdfsFile file = hdfs->OpenFile(fs, filePath.c_str(), O_WRONLY | O_CREAT, 0, 0, 0); if (!file) { - hdfs->Disconnect(fs); - throw runtime_error("Error writing to hdfs file."); + throw runtime_error(StringFormat("Error: Unable to open hdfs file : %s.", filePath.c_str())); } tSize writeBytesNum = 0; size_t loops = fileContent.size(); for (size_t i = 0; i < loops; i++) { - size_t dataCol = dataSize; - size_t writeSize = 0; - size_t idx = 0; - while (dataCol != 0) { - if (dataCol > oneTimeReadWriteLen) { - writeSize = oneTimeReadWriteLen; - } else { - writeSize = 
dataCol; - } - tSize res = hdfs->Write(fs, file, fileContent[i] + idx, writeSize); - if (res == -1) { - hdfs->CloseFile(fs, file); - hdfs->Disconnect(fs); - return static_cast(res); - } - dataCol -= writeSize; - idx += writeSize; - writeBytesNum += res; + tSize res = hdfs->Write(fs, file, fileContent[i].data(), dataSize * sizeof(float)); + if (res == -1) { + hdfs->CloseFile(fs, file); + return static_cast(res); } + writeBytesNum += res; } hdfs->CloseFile(fs, file); - hdfs->Disconnect(fs); return static_cast(writeBytesNum); } @@ -151,15 +108,17 @@ ssize_t HdfsFileSystem::Write(const string& filePath, vector fileContent void HdfsFileSystem::WriteEmbedding(const string& filePath, const int& embeddingSize, const vector& addressArr, int deviceId) { - hdfsFS fs = ConnectHdfs(); - hdfsFile file = hdfs->OpenFile(fs, filePath.c_str(), O_WRONLY | O_CREAT, 0, 0, 0); if (!file) { - hdfs->Disconnect(fs); - throw runtime_error("Error writing to hdfs file."); + throw runtime_error(StringFormat("Error: Unable to open hdfs file : %s.", filePath.c_str())); } #ifndef GTEST + auto res = aclrtSetDevice(static_cast(deviceId)); + if (res != ACL_ERROR_NONE) { + hdfs->CloseFile(fs, file); + throw runtime_error(StringFormat("Set device failed, device_id:%d", deviceId).c_str()); + } for (size_t i = 0; i < addressArr.size(); i += embHashNum) { vector row(embeddingSize); @@ -171,67 +130,52 @@ void HdfsFileSystem::WriteEmbedding(const string& filePath, const int& embedding ACL_MEMCPY_DEVICE_TO_HOST); if (ret != ACL_SUCCESS) { hdfs->CloseFile(fs, file); - hdfs->Disconnect(fs); - throw runtime_error("aclrtMemcpy failed"); + throw runtime_error("Error: Execute aclrtmemcpy from device to host failed."); + } + + tSize res = hdfs->Write(fs, file, row.data(), embeddingSize * sizeof(float)); + if (res == -1) { + hdfs->CloseFile(fs, file); + throw runtime_error(StringFormat("Error: An error occurred while writing file: %s.", filePath.c_str())); } - auto numBytesWritten = hdfs->Write(fs, file, row.data(), embeddingSize * sizeof(float)); - if (numBytesWritten != embeddingSize * sizeof(float)) { + if (res != embeddingSize * sizeof(float)) { hdfs->CloseFile(fs, file); - hdfs->Disconnect(fs); - throw runtime_error("Error writing to hdfs file."); + throw runtime_error(StringFormat("Error: Expected to write %d bytes, " + "but actually write %d bytes to file %s.", + embeddingSize * sizeof(float), res, filePath.c_str())); } } #endif hdfs->CloseFile(fs, file); - hdfs->Disconnect(fs); } ssize_t HdfsFileSystem::Read(const string& filePath, char* fileContent, size_t datasetSize) { - hdfsFS fs = ConnectHdfs(); - hdfsFile file = hdfs->OpenFile(fs, filePath.c_str(), O_RDONLY, 0, 0, 0); if (!file) { - hdfs->Disconnect(fs); - throw runtime_error("open hdfs file failed."); + throw runtime_error(StringFormat("Error: Unable to open hdfs file : %s.", filePath.c_str())); } - size_t dataCol = datasetSize; - size_t idx = 0; - size_t readSize = 0; tSize readBytesNum = 0; - while (dataCol != 0) { - if (dataCol > oneTimeReadWriteLen) { - readSize = oneTimeReadWriteLen; - } else { - readSize = dataCol; - } - tSize res = hdfs->Read(fs, file, fileContent + idx, readSize); - if (res == -1) { - hdfs->CloseFile(fs, file); - hdfs->Disconnect(fs); - return static_cast(res); - } - dataCol -= readSize; - idx += readSize; - readBytesNum += res; + LOG_INFO("Start to read file : {}", filePath); + tSize res = hdfs->Read(fs, file, fileContent, datasetSize); + if (res == -1) { + hdfs->CloseFile(fs, file); + return static_cast(res); } + readBytesNum += res; 
hdfs->CloseFile(fs, file); - hdfs->Disconnect(fs); return static_cast(readBytesNum); } ssize_t HdfsFileSystem::Read(const string& filePath, vector>& fileContent, int64_t contentOffset, vector offsetArr, const size_t& embeddingSize) { - hdfsFS fs = ConnectHdfs(); - hdfsFile file = hdfs->OpenFile(fs, filePath.c_str(), O_RDONLY, 0, 0, 0); if (!file) { - hdfs->Disconnect(fs); - throw runtime_error("open hdfs file failed."); + throw runtime_error(StringFormat("Error: Unable to open hdfs file : %s.", filePath.c_str())); } ssize_t readBytesNum = 0; @@ -241,13 +185,15 @@ ssize_t HdfsFileSystem::Read(const string& filePath, vector>& file tSize res = hdfs->Read(fs, file, fileContent[embeddingCount].data() + contentOffset * embeddingSize, embeddingSize * sizeof(float)); - + if (res == -1) { + hdfs->CloseFile(fs, file); + return static_cast(res); + } embeddingCount++; - readBytesNum += embeddingSize * sizeof(float); + readBytesNum += res; } hdfs->CloseFile(fs, file); - hdfs->Disconnect(fs); return static_cast(readBytesNum); } @@ -261,41 +207,57 @@ void HdfsFileSystem::ReadEmbedding(const string& filePath, EmbeddingSizeInfo& em int deviceId, vector offsetArr) { #ifndef GTEST - hdfsFS fs = ConnectHdfs(); - hdfsFile file = hdfs->OpenFile(fs, filePath.c_str(), O_RDONLY, 0, 0, 0); if (!file) { - hdfs->Disconnect(fs); - throw runtime_error("open hdfs file failed."); + throw runtime_error(StringFormat("Error: Unable to open hdfs file : %s.", filePath.c_str())); + } + + auto res = aclrtSetDevice(static_cast(deviceId)); + if (res != ACL_ERROR_NONE) { + throw runtime_error(StringFormat("Set device failed, device_id:%d", deviceId).c_str()); } float* floatPtr = reinterpret_cast(firstAddress); auto i = 0; for (const auto& offset: offsetArr) { vector row(embedSizeInfo.embeddingSize); - hdfs->Seek(fs, file, offset * embedSizeInfo.embeddingSize * sizeof(float)); + int seekRes = hdfs->Seek(fs, file, offset * embedSizeInfo.embeddingSize * sizeof(float)); + if (seekRes == -1) { + hdfs->CloseFile(fs, file); + throw runtime_error(StringFormat("Error: hdfsSeek failed with error. 
file offset: %d", + offset * embedSizeInfo.embeddingSize * sizeof(float))); + } + tSize res = hdfs->Read(fs, file, row.data(), embedSizeInfo.embeddingSize * sizeof(float)); - try { - aclrtMemcpy(floatPtr + i * embedSizeInfo.extendEmbSize, embedSizeInfo.embeddingSize * sizeof(float), - row.data(), embedSizeInfo.embeddingSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE); - } catch (std::exception& e) { + if (res == -1) { + hdfs->CloseFile(fs, file); + throw runtime_error(StringFormat("Error: An error occurred while reading file: %s.", filePath.c_str())); + } + if (res != embedSizeInfo.embeddingSize * sizeof(float)) { + hdfs->CloseFile(fs, file); + throw runtime_error(StringFormat("Error: Expected to read %d bytes, " + "but actually read %d bytes from file %s.", + embedSizeInfo.embeddingSize * sizeof(float), res, filePath.c_str())); + } + + aclError ret = aclrtMemcpy(floatPtr + i * embedSizeInfo.extendEmbSize, + embedSizeInfo.embeddingSize * sizeof(float), + row.data(), embedSizeInfo.embeddingSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE); + if (ret != ACL_SUCCESS) { hdfs->CloseFile(fs, file); - hdfs->Disconnect(fs); - throw runtime_error(StringFormat("error happen when acl memory copy from host to device: %s", e.what())); + throw runtime_error("Error: Execute aclrtmemcpy from host to device failed."); } i++; } - hdfs->CloseFile(fs, file); - hdfs->Disconnect(fs); #endif } hdfsFS HdfsFileSystem::ConnectHdfs() { - hdfsFS fs = hdfs->Connect("default", 0); - if (!fs) { + hdfsFS hdfsClient = hdfs->Connect("default", 0); + if (!hdfsClient) { throw runtime_error("Connect hdfs file system failed."); } - return fs; + return hdfsClient; } \ No newline at end of file diff --git a/src/core/file_system/hdfs_file_system/hdfs_file_system.h b/src/core/file_system/hdfs_file_system/hdfs_file_system.h index 8d436d3de5e014e4a6cd61f4d292e5affd4b17d0..bf56062f11e0eb9f0f51c94d25956cba600a325d 100644 --- a/src/core/file_system/hdfs_file_system/hdfs_file_system.h +++ b/src/core/file_system/hdfs_file_system/hdfs_file_system.h @@ -24,18 +24,18 @@ namespace MxRec { class HdfsFileSystem : public FileSystem { public: - HdfsFileSystem() + HdfsFileSystem() {}; + ~HdfsFileSystem() { - hdfs = make_unique(); - }; - ~HdfsFileSystem() override {} + hdfs->Disconnect(fs); + } void CreateDir(const string& dirName) override; vector ListDir(const string& dirName) override; size_t GetFileSize(const string& filePath) override; ssize_t Write(const string& filePath, const char* fileContent, size_t dataSize) override; - ssize_t Write(const string& filePath, vector fileContent, size_t dataSize) override; + ssize_t Write(const string& filePath, vector>& fileContent, size_t dataSize) override; void WriteEmbedding(const string& filePath, const int& embeddingSize, const vector& addressArr, int deviceId) override; @@ -47,7 +47,8 @@ namespace MxRec { hdfsFS ConnectHdfs(); - unique_ptr hdfs; + unique_ptr hdfs = make_unique(); + hdfsFS fs = ConnectHdfs(); }; } diff --git a/src/core/file_system/hdfs_file_system/hdfs_wrapper.h b/src/core/file_system/hdfs_file_system/hdfs_wrapper.h index 0f33934f04adc7e44e194e7b2ba243efed043101..b00913ff2bf57a95b438ff9d37b97fed86afa303 100644 --- a/src/core/file_system/hdfs_file_system/hdfs_wrapper.h +++ b/src/core/file_system/hdfs_file_system/hdfs_wrapper.h @@ -134,20 +134,86 @@ namespace MxRec { return hdfsCloseFile(fs, file); } - tSize Read(hdfsFS fs, hdfsFile file, void* buffer, tSize length) const + tSize Read(hdfsFS fs, hdfsFile file, char* buffer, tSize length) const { if (hdfsRead == nullptr) { throw 
runtime_error("Failed to obtain the pointer of the function hdfsRead from the libhdfs."); } - return hdfsRead(fs, file, buffer, length); + + tSize unReadLength = length; + tSize readBytes = 0; + + while (unReadLength != 0) { + tSize offset = (length - unReadLength) / sizeof(char); + tSize res = hdfsRead(fs, file, buffer + offset, unReadLength); + if (res == -1) { + return res; + } + unReadLength -= res; + readBytes += res; + } + return readBytes; + } + + tSize Read(hdfsFS fs, hdfsFile file, float* buffer, tSize length) const + { + if (hdfsRead == nullptr) { + throw runtime_error("Failed to obtain the pointer of the function hdfsRead from the libhdfs."); + } + + tSize unReadLength = length; + tSize readBytes = 0; + + while (unReadLength != 0) { + tSize offset = (length - unReadLength) / sizeof(float); + tSize res = hdfsRead(fs, file, buffer + offset, unReadLength); + if (res == -1) { + return res; + } + unReadLength -= res; + readBytes += res; + } + return readBytes; } - tSize Write(hdfsFS fs, hdfsFile file, const void* buffer, tSize length) const + tSize Write(hdfsFS fs, hdfsFile file, const char* buffer, tSize length) const { if (hdfsWrite == nullptr) { throw runtime_error("Failed to obtain the pointer of the function hdfsWrite from the libhdfs."); } - return hdfsWrite(fs, file, buffer, length); + tSize unWriteLength = length; + tSize writeBytes = 0; + + while (unWriteLength != 0) { + tSize offset = (length - unWriteLength) / sizeof(char); + tSize res = hdfsWrite(fs, file, buffer + offset, unWriteLength); + if (res == -1) { + return res; + } + unWriteLength -= res; + writeBytes += res; + } + return writeBytes; + } + + tSize Write(hdfsFS fs, hdfsFile file, const float* buffer, tSize length) const + { + if (hdfsWrite == nullptr) { + throw runtime_error("Failed to obtain the pointer of the function hdfsWrite from the libhdfs."); + } + tSize unWriteLength = length; + tSize writeBytes = 0; + + while (unWriteLength != 0) { + tSize offset = (length - unWriteLength) / sizeof(float); + tSize res = hdfsWrite(fs, file, buffer + offset, unWriteLength); + if (res == -1) { + return res; + } + unWriteLength -= res; + writeBytes += res; + } + return writeBytes; } int Seek(hdfsFS fs, hdfsFile file, tOffset desiredPos) const diff --git a/src/core/file_system/local_file_system/local_file_system.cpp b/src/core/file_system/local_file_system/local_file_system.cpp index 43cd00334372d8b02fa0bd0f4a3ba1d8065d914b..b0b5c76ad1dc07a153bb77e23668770ba63a229d 100644 --- a/src/core/file_system/local_file_system/local_file_system.cpp +++ b/src/core/file_system/local_file_system/local_file_system.cpp @@ -38,13 +38,13 @@ void LocalFileSystem::CreateDir(const string& dirName) while (getline(input, tmp, '/')) { guard++; if (guard > maxDepth) { - throw runtime_error(StringFormat("create directory {} exceed max depth", dirName.c_str())); + throw runtime_error(StringFormat("create directory %s exceed max depth", dirName.c_str())); } ss << tmp << '/'; int ret = mkdir(ss.str().c_str(), dirMode); if (ret != 0 && errno != EEXIST) { LOG_ERROR("Unable to create directory: {} ret:{} error info: {}", dirName, ret, strerror(errno)); - throw runtime_error(StringFormat("create directory {} failed: {}", dirName.c_str(), strerror(errno))); + throw runtime_error(StringFormat("create directory %s failed: %s", dirName.c_str(), strerror(errno))); } } } @@ -112,44 +112,42 @@ ssize_t LocalFileSystem::Write(const string& filePath, const char* fileContent, return writeBytesNum; } -ssize_t LocalFileSystem::Write(const string& filePath, vector 
diff --git a/src/core/file_system/local_file_system/local_file_system.cpp index 43cd00334372d8b02fa0bd0f4a3ba1d8065d914b..b0b5c76ad1dc07a153bb77e23668770ba63a229d 100644 --- a/src/core/file_system/local_file_system/local_file_system.cpp +++ b/src/core/file_system/local_file_system/local_file_system.cpp @@ -38,13 +38,13 @@ void LocalFileSystem::CreateDir(const string& dirName) while (getline(input, tmp, '/')) { guard++; if (guard > maxDepth) { - throw runtime_error(StringFormat("create directory {} exceed max depth", dirName.c_str())); + throw runtime_error(StringFormat("create directory %s exceeds max depth", dirName.c_str())); } ss << tmp << '/'; int ret = mkdir(ss.str().c_str(), dirMode); if (ret != 0 && errno != EEXIST) { LOG_ERROR("Unable to create directory: {} ret:{} error info: {}", dirName, ret, strerror(errno)); - throw runtime_error(StringFormat("create directory {} failed: {}", dirName.c_str(), strerror(errno))); + throw runtime_error(StringFormat("create directory %s failed: %s", dirName.c_str(), strerror(errno))); } } } @@ -112,44 +112,42 @@ ssize_t LocalFileSystem::Write(const string& filePath, const char* fileContent, return writeBytesNum; } -ssize_t LocalFileSystem::Write(const string& filePath, vector fileContent, size_t dataSize) +ssize_t LocalFileSystem::Write(const string& filePath, vector<vector<float>>& fileContent, size_t dataSize) { int fd = open(filePath.c_str(), O_RDWR | O_CREAT | O_TRUNC, fileMode); if (fd == -1) { throw runtime_error(StringFormat("open file %s to write failed.", filePath.c_str())); } - buffer.reserve(BUFFER_SIZE); - BufferQueue queue; - ssize_t writeBytesNum = 0; - std::thread writer(&LocalFileSystem::WriterFn, this, std::ref(queue), fd, std::ref(writeBytesNum)); - - size_t loops = fileContent.size(); - for (size_t i = 0; i < loops; i++) { - size_t idx = 0; - size_t writeSize = 0; - size_t dataCol = dataSize; - while (dataCol != 0) { - if (dataCol > oneTimeReadWriteLen) { - writeSize = oneTimeReadWriteLen; - } else { - writeSize = dataCol; - } - FillToBuffer(queue, reinterpret_cast(fileContent[i]) + idx, writeSize); - dataCol -= writeSize; - idx += writeSize; - } + vector<float> flattenContent; + for (auto& vec : fileContent) { + flattenContent.insert(flattenContent.cend(), vec.cbegin(), vec.cend()); } - // After all data has been processed, check if there is any data left in the buffer - if (!buffer.empty()) { - queue.Push(std::move(buffer)); - buffer.clear(); + size_t writeBytesRemain = flattenContent.size() * sizeof(float); + size_t writeSize = 0; + size_t idx = 0; + ssize_t writeBytesNum = 0; + auto dumpPtr = reinterpret_cast<const char*>(flattenContent.data()); + + while (writeBytesRemain != 0) { + if (writeBytesRemain > oneTimeReadWriteLen) { + writeSize = oneTimeReadWriteLen; + } else { + writeSize = writeBytesRemain; + } + ssize_t res = write(fd, dumpPtr + idx, writeSize); + if (res == -1) { + close(fd); + return res; + } + writeBytesRemain -= res; + idx += res; + writeBytesNum += res; } - queue.Push(std::vector()); - writer.join(); close(fd); + return writeBytesNum; } @@ -168,6 +166,12 @@ void LocalFileSystem::WriteEmbedding(const string& filePath, const int& embeddin } #ifndef GTEST + auto res = aclrtSetDevice(static_cast<int32_t>(deviceId)); + if (res != ACL_ERROR_NONE) { + close(fd); + throw runtime_error(StringFormat("Set device failed, device_id:%d", deviceId).c_str()); + } + for (size_t i = 0; i < addressArr.size(); i += keyAddrElem) { vector row(embeddingSize); int64_t address = addressArr.at(i); @@ -271,6 +275,10 @@ void LocalFileSystem::ReadEmbedding(const string& filePath, EmbeddingSizeInfo& e if (fp == nullptr) { throw runtime_error(StringFormat("Failed to open read file: %s", filePath.c_str())); } + auto res = aclrtSetDevice(static_cast<int32_t>(deviceId)); + if (res != ACL_ERROR_NONE) { + throw runtime_error(StringFormat("Set device failed, device_id:%d", deviceId).c_str()); + } float* floatPtr = reinterpret_cast<float*>(firstAddress); auto i = 0; diff --git a/src/core/file_system/local_file_system/local_file_system.h b/src/core/file_system/local_file_system/local_file_system.h index d137f158d383d314588409576ee70f90f38e34a2..9b09f34d5516b4b8358640c0255c1efcd42528d1 100644 --- a/src/core/file_system/local_file_system/local_file_system.h +++ b/src/core/file_system/local_file_system/local_file_system.h @@ -33,7 +33,7 @@ namespace MxRec { size_t GetFileSize(const string& filePath) override; ssize_t Write(const string& filePath, const char* fileContent, size_t dataSize) override; - ssize_t Write(const string& filePath, vector fileContent, size_t dataSize) override; + ssize_t Write(const string& filePath, vector<vector<float>>& fileContent, size_t dataSize) override; void WriteEmbedding(const string& filePath, const int& embeddingSize, const vector& addressArr, int deviceId) override; @@ -46,8 +46,6 @@ namespace
MxRec { void WriterFn(BufferQueue& queue, int fd, ssize_t& writerBytesNum); void FillToBuffer(BufferQueue& queue, const char* data, size_t dataSize); void CalculateMapSize(off_t fileSize, size_t& mapByteSize, size_t& mapRowNum, size_t onceReadByteSize) const; - void HandleMappedData(char* mappedData, size_t mapRowNum, size_t onceReadByteSize, - vector>& dst, size_t cnt) const; private: const mode_t dirMode; diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp index 7bd083ab0a3da81b6ced4f65262269cb3c8ef5c6..8fc2a28258a0ba3668b12e6dc7308324f9c97fbe 100644 --- a/src/core/hd_transfer/hd_transfer.cpp +++ b/src/core/hd_transfer/hd_transfer.cpp @@ -50,7 +50,14 @@ int HDTransfer::Init(const vector& embInfos, uint32_t localRankId) CreateChannel(localRankId, embInfo.name, i); } // create acltdtDataset objects (the ACL counterpart of a vector); synchronous interface - aclDatasets[embInfo.name] = acltdtCreateDataset(); + for (int j = 0; j < EMBEDDING_THREAD_NUM; j++) { + acltdtDataset* dataset = acltdtCreateDataset(); + if (dataset == nullptr) { + LOG_ERROR("create acltdtDataset failed, table:{}, threadId:{}", embInfo.name, j); + throw runtime_error("create acltdtDataset failed"); + } + aclDatasets[embInfo.name][j] = dataset; + } } running = true; LOG(INFO) << "hd_transfer init"; @@ -71,9 +78,11 @@ void HDTransfer::Destroy() } LOG_INFO(HD + "destroy channel:{}", c.first); } - for (auto& d: aclDatasets) { - if (acltdtDestroyDataset(d.second) != ACL_ERROR_NONE) { - throw runtime_error("Acl destroy tensor dataset failed."); + for (auto& datasetMap: aclDatasets) { + for (auto &d: datasetMap.second) { + if (acltdtDestroyDataset(d.second) != ACL_ERROR_NONE) { + throw runtime_error("Acl destroy tensor dataset failed."); + } } } aclFinalize(); @@ -90,20 +99,30 @@ void HDTransfer::CreateChannel(const uint32_t localRankId, const string& embName int channelSize = GlobalEnv::hdChannelSize; LOG_INFO("user config all2all restore lookup channel size:{}", channelSize); for (int c = static_cast<int>(TransferChannel::D2H); c != static_cast<int>(TransferChannel::INVALID); c++) { + if ((c == static_cast<int>(TransferChannel::SWAP) || c == static_cast<int>(TransferChannel::D2H) || + c == static_cast<int>(TransferChannel::H2D)) && channelNum == EVAL_CHANNEL_ID) { + continue; + } + auto channel = static_cast<TransferChannel>(c); - string sendName = StringFormat( - "%s_%s_%d", embName.c_str(), TransferChannel2Str(channel).c_str(), channelNum - ); + std::string sendName; + if (c == static_cast<int>(TransferChannel::SWAP) || c == static_cast<int>(TransferChannel::D2H) || + c == static_cast<int>(TransferChannel::H2D)) { + sendName = StringFormat("%s_%s_all", embName.c_str(), TransferChannel2Str(channel).c_str()); + } else { + sendName = StringFormat("%s_%s_%d", embName.c_str(), TransferChannel2Str(channel).c_str(), channelNum); + } if (TransferChannel2Str(channel) == "all2all" || TransferChannel2Str(channel) == "restore" || TransferChannel2Str(channel) == "lookup" || TransferChannel2Str(channel) == "restore_second" || TransferChannel2Str(channel) == "uniquekeys" || - TransferChannel2Str(channel) == "evict" /* for noDDR */ + TransferChannel2Str(channel) == "evict" || + TransferChannel2Str(channel) == "swap" ) { - transferChannels[sendName] = tdtCreateChannel(localRankId, sendName.c_str(), channelSize); + transferChannels[sendName] = TDT_CREATE_CHANNEL(localRankId, sendName.c_str(), channelSize); } else { - transferChannels[sendName] = tdtCreateChannel(localRankId, sendName.c_str(), PING_PONG_SIZE); + transferChannels[sendName] = TDT_CREATE_CHANNEL(localRankId, sendName.c_str(), PING_PONG_SIZE); }
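+ // Clarifying note: swap/d2h/h2d each use a single "_all" channel per table,
+ // shared by the train and eval graphs (their eval-side creation is skipped
+ // above); every other channel keeps its per-channel-id name.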
LOG_INFO("create channel:{} {}", sendName, static_cast(transferChannels[sendName])); } @@ -128,10 +147,16 @@ void HDTransfer::Send(TransferChannel channel, const vector &tensors, in for (auto& t: tensors) { sizes.push_back(t.NumElements()); } - string sendName = StringFormat("%s_%s_%d", embName.c_str(), TransferChannel2Str(channel).c_str(), channelId); - LOG_INFO(HD + "hd transfer send {}, send count is {}, size list:{}", - sendName, sizes.size(), VectorToString(sizes)); + string sendName; + if (channel == TransferChannel::SWAP || channel == TransferChannel::D2H || channel == TransferChannel::H2D) { + sendName = StringFormat("%s_%s_all", embName.c_str(), TransferChannel2Str(channel).c_str()); + } else { + sendName = StringFormat("%s_%s_%d", embName.c_str(), TransferChannel2Str(channel).c_str(), channelId); + } + + LOG_INFO(HD + "hd transfer send:{}, batchId:{}, send count:{}, size list:{}", + sendName, batchId, sizes.size(), VectorToString(sizes)); if (sizes.size() == 0) { LOG_WARN("tensors num can not be zero"); @@ -171,9 +196,15 @@ void HDTransfer::Send(TransferChannel channel, const vector &tensors, in vector HDTransfer::Recv(TransferChannel channel, int channelId, const string& embName) { EASY_FUNCTION() + vector tensors; #ifndef GTEST - std::vector tensors; - string recvName = StringFormat("%s_%s_%d", embName.c_str(), TransferChannel2Str(channel).c_str(), channelId); + string recvName; + if (channel == TransferChannel::SWAP || channel == TransferChannel::D2H || channel == TransferChannel::H2D) { + recvName = StringFormat("%s_%s_all", embName.c_str(), TransferChannel2Str(channel).c_str()); + } else { + recvName = StringFormat("%s_%s_%d", embName.c_str(), TransferChannel2Str(channel).c_str(), channelId); + } + LOG_DEBUG("hd transfer try recv:{}", recvName); TimeCost tc = TimeCost(); tensorflow::Status status = tensorflow::RecvTensorByAcl(transferChannels[recvName], tensors); @@ -190,8 +221,8 @@ vector HDTransfer::Recv(TransferChannel channel, int channel sizes.push_back(t.NumElements()); } LOG_INFO("hd transfer recv:{}, size:{} cost:{}ms", recvName, VectorToString(sizes), tc.ElapsedMS()); - return tensors; #endif + return tensors; } /// 接收从device发送过来的数据(D2H), updateEmbV2函数使用;使用原生的aclTDT接口 @@ -199,27 +230,36 @@ vector HDTransfer::Recv(TransferChannel channel, int channel /// \param channelId 通道索引(训练/推理) /// \param embName 表名 /// \return -size_t HDTransfer::RecvAcl(TransferChannel channel, int channelId, const string& embName) +size_t HDTransfer::RecvAcl(TransferChannel channel, int channelId, const string& embName, + int embeddingThreadId, int batchId) { EASY_FUNCTION() + size_t ret = 0; #ifndef GTEST - std::vector tensors; - string recvName = StringFormat("%s_%s_%d", embName.c_str(), TransferChannel2Str(channel).c_str(), channelId); - LOG_DEBUG("hd transfer try recv:{}", recvName); + string recvName; + if (channel == TransferChannel::SWAP || channel == TransferChannel::D2H || channel == TransferChannel::H2D) { + recvName = StringFormat("%s_%s_all", embName.c_str(), TransferChannel2Str(channel).c_str()); + } else { + recvName = StringFormat("%s_%s_%d", embName.c_str(), TransferChannel2Str(channel).c_str(), channelId); + } + + LOG_DEBUG("hd transfer try recv:{}, batchId:{}", recvName, batchId); TimeCost tc = TimeCost(); - if (aclDatasets[embName] == nullptr) { + if (aclDatasets[embName][embeddingThreadId] == nullptr) { throw runtime_error(StringFormat("Failed recv:%s.", recvName.c_str()).c_str()); } - auto aclStatus = acltdtReceiveTensor(transferChannels[recvName], aclDatasets[embName], 
GlobalEnv::aclTimeout); + auto aclStatus = acltdtReceiveTensor( + transferChannels[recvName], aclDatasets[embName][embeddingThreadId], GlobalEnv::aclTimeout); if (!running) { return 0; } if (aclStatus != ACL_ERROR_NONE && aclStatus != ACL_ERROR_RT_QUEUE_EMPTY) { throw runtime_error(StringFormat("Failed receive data from acl channel, acl status:%d", aclStatus).c_str()); } - LOG_INFO("hd transfer recv:{} cost:{}ms", recvName, tc.ElapsedMS()); - return acltdtGetDatasetSize(aclDatasets[embName]); + LOG_INFO("hd transfer recv:{}, batchId:{}, cost:{}ms", recvName, batchId, tc.ElapsedMS()); + ret = acltdtGetDatasetSize(aclDatasets[embName][embeddingThreadId]); #endif + return ret; } std::unordered_map HDTransfer::GetTransChannel() diff --git a/src/core/hd_transfer/hd_transfer.h b/src/core/hd_transfer/hd_transfer.h index 0ff29e1baeb48e46e1df16f6c5d6c71a9eacd16f..58c480670d2e8fe5270466a159fa18deb81ae442 100644 --- a/src/core/hd_transfer/hd_transfer.h +++ b/src/core/hd_transfer/hd_transfer.h @@ -24,8 +24,8 @@ See the License for the specific language governing permissions and #include "utils/common.h" #include "utils/config.h" -#ifndef tdtCreateChannel -#define tdtCreateChannel acltdtCreateChannelWithCapacity +#ifndef TDT_CREATE_CHANNEL +#define TDT_CREATE_CHANNEL acltdtCreateChannelWithCapacity #endif namespace MxRec { @@ -45,6 +45,8 @@ namespace MxRec { EVICT, H2D, SWAP, + SAVE_D2H, + SAVE_H2D, INVALID }; @@ -69,6 +71,10 @@ namespace MxRec { return "h2d"; case TransferChannel::SWAP: return "swap"; + case TransferChannel::SAVE_D2H: + return "save_d2h"; + case TransferChannel::SAVE_H2D: + return "save_h2d"; default: throw std::invalid_argument("Invalid TransferChannel"); } @@ -76,7 +82,7 @@ namespace MxRec { class HDTransfer { public: - std::unordered_map aclDatasets; + std::unordered_map> aclDatasets; HDTransfer() = default; @@ -87,7 +93,8 @@ namespace MxRec { vector Recv(TransferChannel channel, int channelId, const string& embName); - size_t RecvAcl(TransferChannel channel, int channelId, const string& embName); + size_t RecvAcl(TransferChannel channel, int channelId, const string& embName, + int embeddingThreadId, int batchId); void Destroy(); diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp deleted file mode 100644 index ce0e0a78e133f5a36bc4d695184a9bc9cbc342b7..0000000000000000000000000000000000000000 --- a/src/core/host_emb/host_emb.cpp +++ /dev/null @@ -1,278 +0,0 @@ -/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and - limitations under the License. 
-==============================================================================*/ - -#include "host_emb.h" -#include -#include "hd_transfer/hd_transfer.h" -#include "checkpoint/checkpoint.h" -#include "initializer/initializer.h" -#include "utils/time_cost.h" - -using namespace MxRec; -using namespace std; -using namespace chrono; - -/// 初始化host emb -/// \param embInfos 表信息列表 -/// \param seed 随机种子 -/// \return -void HostEmb::Initialize(const vector& embInfos, int seed) -{ - for (const auto& embInfo: embInfos) { - HostEmbTable hostEmb; - hostEmb.hostEmbInfo = embInfo; - EmbDataGenerator(embInfo.initializeInfos, seed, static_cast(embInfo.hostVocabSize), - embInfo.extEmbeddingSize, hostEmb.embData); - hostEmbs[embInfo.name] = move(hostEmb); - LOG_INFO(HOSTEMB + "HostEmb Initialize End"); - } -} - -/// 根据指定的初始化器对emb进行初始化 -/// \param initializeInfos emb初始化信息列表 -/// \param seed 随机种子 -/// \param vocabSize host表大小 -/// \param embeddingSize emb维度 -/// \param embData emb数据 -void HostEmb::EmbDataGenerator(const vector &initializeInfos, int seed, int vocabSize, - int embeddingSize, vector> &embData) const -{ -#ifndef GTEST - LOG_INFO(HOSTEMB + "GenerateEmbData Start, seed:{}, initializer num: {}", seed, initializeInfos.size()); - embData.clear(); - embData.resize(vocabSize, vector(embeddingSize)); - - for (auto initializeInfo: initializeInfos) { - LOG_INFO("Device GenerateEmbData ing. name {}", initializeInfo.name); - for (int i = 0; i < vocabSize; i++) { - initializeInfo.initializer->GenerateData(embData.at(i).data(), embeddingSize); - } - } - LOG_INFO(HOSTEMB + "GenerateEmbData End, seed:{}", seed); -#endif -} - -/// 停止用于异步更新D2H emb的线程 -/// \param channelId 通道索引(训练/推理) -void HostEmb::Join(int channelId) -{ - TimeCost tc = TimeCost(); - switch (channelId) { - case TRAIN_CHANNEL_ID: - LOG_DEBUG(HOSTEMB + "start join, channelId:{}, procThreadsForTrain num:{}", - channelId, procThreadsForTrain.size()); - for (auto& t: procThreadsForTrain) { - t->join(); - } - procThreadsForTrain.clear(); - LOG_DEBUG(HOSTEMB + "end join, channelId:{}, cost:{}ms", channelId, tc.ElapsedMS()); - break; - case EVAL_CHANNEL_ID: - LOG_DEBUG(HOSTEMB + "start join, channelId:{}, procThreadsForEval num:{}", - channelId, procThreadsForEval.size()); - for (auto& t: procThreadsForEval) { - t->join(); - } - procThreadsForEval.clear(); - LOG_DEBUG(HOSTEMB + "end join, channelId:{}, cost:{}ms", channelId, tc.ElapsedMS()); - break; - default: - throw invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); - } -} - -#ifndef GTEST -/// 从hdTransfer获取device侧返回的emb信息,并在host侧表的对应位置插入。 -/// missingKeysHostPos为host侧需要发送的emb的位置,也就是淘汰的emb的插入位置 -/// \param missingKeysHostPos 当前batch在host上需要换出的偏移 -/// \param channelId 通道索引(训练/推理) -/// \param embName 表名 -void HostEmb::UpdateEmb(const vector& missingKeysHostPos, int channelId, const string& embName) -{ - LOG_INFO(HOSTEMB + "UpdateEmb, channelId:{}, embName:{}", channelId, embName); - EASY_FUNCTION(profiler::colors::Purple); - TimeCost tc = TimeCost(); - auto hdTransfer = Singleton::GetInstance(); - TransferChannel transferName = TransferChannel::D2H; - LOG_INFO(HOSTEMB + "wait D2H embs, channelId:{}", channelId); - const auto tensors = hdTransfer->Recv(transferName, channelId, embName); - if (tensors.empty()) { - LOG_WARN(HOSTEMB + "recv empty data"); - return; - } - const Tensor& d2hEmb = tensors[0]; - EASY_BLOCK("Update") - const float* tensorPtr = d2hEmb.flat().data(); - auto embeddingSize = hostEmbs[embName].hostEmbInfo.extEmbeddingSize; - auto& embData = 
hostEmbs[embName].embData; - - LOG_DEBUG(HOSTEMB + "embName:{}, UpdateEmb missingKeys len = {}, embeddingSize = {}, " - "embData.size = {} {}", embName, missingKeysHostPos.size(), embeddingSize, embData.size(), tensorPtr); - -#pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) \ - shared(missingKeysHostPos, tensorPtr, embData, embeddingSize) - for (size_t i = 0; i < missingKeysHostPos.size(); i++) { - auto& dst = embData[missingKeysHostPos[i]]; -#pragma omp simd - for (int j = 0; j < embeddingSize; j++) { - dst[j] = tensorPtr[j + embeddingSize * i]; - } - } - LOG_INFO(HOSTEMB + "update emb end cost: {}ms", tc.ElapsedMS()); - EASY_END_BLOCK -} - -/// 用从device获取的数据更新host的emb(使用aclTDT原生接口) -/// \param missingKeysHostPos 当前batch在host上需要换出的偏移 -/// \param channelId 通道索引(训练/推理) -/// \param embName 表名 -void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelId, const string& embName) -{ - LOG_INFO(HOSTEMB + "UpdateEmbV2, channelId:{}, embName:{}", channelId, embName); - EASY_FUNCTION(profiler::colors::Purple) - auto updateThread = - [this, missingKeysHostPos, channelId, embName] { - auto hdTransfer = Singleton::GetInstance(); - TransferChannel transferName = TransferChannel::D2H; - LOG_INFO(HOSTEMB + "wait D2H embs, channelId:{}", channelId); - auto size = hdTransfer->RecvAcl(transferName, channelId, embName); - if (size == 0) { - LOG_WARN(HOSTEMB + "recv empty data"); - return; - } - TimeCost tc = TimeCost(); - - EASY_BLOCK("Update") - auto& embData = hostEmbs[embName].embData; - auto embeddingSize = hostEmbs[embName].hostEmbInfo.extEmbeddingSize; - auto aclData = acltdtGetDataItem(hdTransfer->aclDatasets[embName], 0); - if (aclData == nullptr) { - throw runtime_error("Acl get tensor data from dataset failed."); - } - float* ptr = static_cast(acltdtGetDataAddrFromItem(aclData)); - if (ptr == nullptr || missingKeysHostPos.size() == 0) { - return; - } - size_t elementSize = acltdtGetDataSizeFromItem(aclData); - size_t dimNum = acltdtGetDimNumFromItem(aclData); - LOG_DEBUG(HOSTEMB + "embName:{}, UpdateEmb missingKeys len = {}, embeddingSize = {}," - " embData.size = {}, RecvAcl = {}, elementSize = {}, dimNum = {}", - embName, missingKeysHostPos.size(), embeddingSize, embData.size(), size, elementSize, dimNum); -#pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) shared(ptr, embData, embeddingSize) - for (size_t j = 0; j < missingKeysHostPos.size(); j++) { - auto& dst = embData[missingKeysHostPos[j]]; -#pragma omp simd - for (int k = 0; k < embeddingSize; k++) { - dst[k] = ptr[k + embeddingSize * j]; - } - } - LOG_INFO(HOSTEMB + "update emb end cost: {}ms", tc.ElapsedMS()); - }; - - switch (channelId) { - case TRAIN_CHANNEL_ID: - procThreadsForTrain.emplace_back(make_unique(updateThread)); - break; - case EVAL_CHANNEL_ID: - procThreadsForEval.emplace_back(make_unique(updateThread)); - break; - default: - throw invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); - } -} - -/// 查找host侧需要发送给device的emb数据。 -/// \param missingKeysHostPos 当前batch在host上需要换出的偏移 -/// \param embName -/// \param h2dEmbOut -void HostEmb::GetH2DEmb(const vector& missingKeysHostPos, const string& embName, - vector& h2dEmbOut) -{ - EASY_FUNCTION() - TimeCost tc = TimeCost(); - const auto& emb = hostEmbs[embName]; - const int embeddingSize = emb.hostEmbInfo.extEmbeddingSize; - h2dEmbOut.emplace_back(Tensor(tensorflow::DT_FLOAT, { - int(missingKeysHostPos.size()), embeddingSize - })); - auto& tmpTensor = h2dEmbOut.back(); - auto tmpData = tmpTensor.flat(); -#pragma 
omp parallel for num_threads(MGMT_CPY_THREADS) default(none) shared(missingKeysHostPos, emb, tmpData) - for (size_t i = 0; i < missingKeysHostPos.size(); ++i) { - const auto& src = emb.embData[missingKeysHostPos[i]]; -#pragma omp simd - for (int j = 0; j < embeddingSize; j++) { - tmpData(j + i * embeddingSize) = src[j]; - } - } - LOG_INFO("GetH2DEmb end, missingKeys count:{} cost:{}ms", missingKeysHostPos.size(), tc.ElapsedMS()); -} - -/// 获取hostEmbs的指针 -/// \return -auto HostEmb::GetHostEmbs() -> absl::flat_hash_map* -{ - return &hostEmbs; -} - -/// 对指定offset的emb进行初始化 -/// \param initializeInfos emb初始化信息列表 -/// \param embData emb数据 -/// \param offset 偏移列表 -void HostEmb::EmbPartGenerator(const vector &initializeInfos, vector> &embData, - const vector& offset) const -{ - for (auto initializeInfo: initializeInfos) { - LOG_INFO("Device GenerateEmbData ing. name {}", initializeInfo.name); - for (size_t i = 0; i < offset.size(); ++i) { - initializeInfo.initializer->GenerateData(embData.at(offset.at(i)).data(), - static_cast(embData[0].size())); - } - } -} - -void HostEmb::EmbPartGenerator(const vector &initializeInfos, vector> &embData, - const vector& offset) const -{ - for (auto initializeInfo: initializeInfos) { - LOG_INFO("Device GenerateEmbData ing. name {}", initializeInfo.name); - for (size_t i = 0; i < offset.size(); ++i) { - initializeInfo.initializer->GenerateData(embData.at(offset.at(i)).data(), - static_cast(embData[0].size())); - } - } -} -#endif - -/// 利用initializer初始化emb淘汰的位置 -/// \param embName 表名 -/// \param offset 淘汰的偏移列表 -void HostEmb::EvictInitEmb(const string& embName, const vector& offset) -{ -#ifndef GTEST - auto& hostEmb = GetEmb(embName); - EmbPartGenerator(hostEmb.hostEmbInfo.initializeInfos, hostEmb.embData, offset); - LOG_INFO(HOSTEMB + "ddr EvictInitEmb!host embName {}, init offsets size: {}", embName, offset.size()); -#endif -} - -void HostEmb::EvictInitEmb(const string& embName, const vector& offset) -{ -#ifndef GTEST - auto& hostEmb = GetEmb(embName); - EmbPartGenerator(hostEmb.hostEmbInfo.initializeInfos, hostEmb.embData, offset); - LOG_INFO(HOSTEMB + "ddr EvictInitEmb!host embName {}, init offsets size: {}", embName, offset.size()); -#endif -} \ No newline at end of file diff --git a/src/core/host_emb/host_emb.h b/src/core/host_emb/host_emb.h deleted file mode 100644 index a9ff378649d658abb16d57fc2511587aafc56453..0000000000000000000000000000000000000000 --- a/src/core/host_emb/host_emb.h +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and - limitations under the License. 
-==============================================================================*/ - -#ifndef MX_REC_HOSTEMB_H -#define MX_REC_HOSTEMB_H - -#include -#include -#include -#include -#include "absl/container/flat_hash_map.h" -#include "utils/common.h" -#include "utils/singleton.h" -#include "tensorflow/core/framework/tensor.h" - -namespace MxRec { - using namespace std; - using namespace tensorflow; - - class HostEmb { - public: - HostEmb() = default; - - ~HostEmb() - {}; - - void Initialize(const vector& embInfos, int seed); - - void Join(int channelId); - - void UpdateEmb(const vector& missingKeysHostPos, int channelId, const string& embName); - - void UpdateEmbV2(const vector& missingKeysHostPos, int channelId, const string& embName); - - void GetH2DEmb(const vector& missingKeysHostPos, const string& embName, - vector& h2dEmbOut); - auto GetHostEmbs() -> absl::flat_hash_map*; - - void EvictInitEmb(const string& embName, const vector& offset); - - void EvictInitEmb(const string& embName, const vector& offset); - - HostEmbTable& GetEmb(const string& embName) - { - return hostEmbs.at(embName); - } - - GTEST_PRIVATE: - absl::flat_hash_map hostEmbs; - - std::vector> procThreadsForTrain; - std::vector> procThreadsForEval; - - void EmbDataGenerator(const vector& initializeInfos, int seed, int vocabSize, int embeddingSize, - vector>& embData) const; - void EmbPartGenerator(const vector &initializeInfos, vector> &embData, - const vector& offset) const; - - void EmbPartGenerator(const vector &initializeInfos, vector> &embData, - const vector& offset) const; - }; -} - -#endif // MX_REC_HOSTEMB_H diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 894dc230dd4012b99a5bd0188be17704b6d9ff3d..4801f95b14df5ef31add10146e87df7bec3e7ce8 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -15,27 +15,27 @@ See the License for the specific language governing permissions and #include "hybrid_mgmt.h" +#include + #include +#include #include -#include #include #include +#include "checkpoint/checkpoint.h" +#include "emb_table/embedding_mgmt.h" #include "hd_transfer/hd_transfer.h" #include "hybrid_mgmt/hybrid_mgmt_block.h" -#include "utils/time_cost.h" -#include "utils/logger.h" -#include "utils/common.h" -#include "checkpoint/checkpoint.h" -#include "key_process/key_process.h" #include "key_process/feature_admit_and_evict.h" -#include "emb_table/embedding_mgmt.h" -#include "emb_table/embedding_ddr.h" - +#include "key_process/key_process.h" +#include "utils/common.h" +#include "utils/logger.h" +#include "utils/time_cost.h" using namespace MxRec; using namespace std; - +using namespace ock::ctr; /// Openmpi通信域进程数设置、计算所有表host特征数量总数、设置训练模式(HBM/DDR) /// \param rankInfo @@ -48,17 +48,17 @@ void HybridMgmt::InitRankInfo(RankInfo& rankInfo, const vector& embInfo // 计算训练任务涉及的所有表在DDR中需要分配的key数量 size_t totHostVocabSize = 0; - size_t totalSsdVocabSize = 0; + size_t totalL3StorageVocabSize = 0; for (const auto& emb : embInfos) { totHostVocabSize += emb.hostVocabSize; - totalSsdVocabSize += emb.ssdVocabSize; + totalL3StorageVocabSize += emb.ssdVocabSize; } // 根据DDR的key数量,配置存储模式HBM/DDR if (totHostVocabSize != 0) { rankInfo.isDDR = true; } - if (totalSsdVocabSize != 0) { + if (totalL3StorageVocabSize != 0) { rankInfo.isSSDEnabled = true; } #endif @@ -89,12 +89,17 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, return true; } + // create factory for fastUnique and embeddingCache + int result = 
ock::ctr::Factory::Create(factory); + if (result != 0) { + throw runtime_error(Logger::Format("create fast factory failed, error code:{}", result)); + } + InitRankInfo(rankInfo, embInfos); - EmbeddingMgmt::Instance()->Init(rankInfo, embInfos, thresholdValues, seed); GlogConfig::gStatOn = GlobalEnv::statOn; - LOG_INFO(MGMT + "begin initialize, localRankSize:{}, localRankId:{}, rank:{}", - rankInfo.localRankSize, rankInfo.localRankId, rankInfo.rankId); + LOG_INFO(MGMT + "begin initialize, localRankSize:{}, localRankId:{}, rank:{}", rankInfo.localRankSize, + rankInfo.localRankId, rankInfo.rankId); mgmtRankInfo = rankInfo; mgmtEmbInfo = embInfos; @@ -110,125 +115,44 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, KEY_PROCESS_INSTANCE->Initialize(rankInfo, embInfos, thresholdValues, seed); isRunning = true; + isL3StorageEnabled = rankInfo.isSSDEnabled; + EmbeddingMgmt::Instance()->Init(rankInfo, embInfos, seed); - // DDR模式,初始化hashmap和host emb if (rankInfo.isDDR) { - hostEmbs = Singleton::GetInstance(); - hostHashMaps = make_unique(); - hostEmbs->Initialize(embInfos, seed); - hostHashMaps->Init(rankInfo, embInfos, ifLoad); + InitEmbeddingCache(embInfos); } - // 非断点续训模式,启动数据传输 - isSSDEnabled = rankInfo.isSSDEnabled; - if (isSSDEnabled) { + if (isL3StorageEnabled) { cacheManager = Singleton::GetInstance(); - cacheManager->Init(hostEmbs, mgmtEmbInfo); - hostHashMaps->isSSDEnabled = this->isSSDEnabled; - hostHashMaps->cacheManager = this->cacheManager; - // 启用SSD时,EmbeddingDDR依赖cacheManager - EmbeddingMgmt::Instance()->EnableSSD(); - EmbeddingMgmt::Instance()->SetCacheManagerForEmbTable(this->cacheManager); + // 用户可实现L3Storage接口替换SSDEngine以对接外部存储服务 + auto ssdEngine = std::make_shared(); + cacheManager->Init(embCache, mgmtEmbInfo, ssdEngine); + EmbeddingMgmt::Instance()->SetCacheManagerForEmbTable(cacheManager); } isLoad = ifLoad; if (!isLoad) { Start(); } - for (const auto& info: embInfos) { - LOG_INFO(MGMT + "emb[{}] vocab size {}+{} sc:{}", - info.name, info.devVocabSize, info.hostVocabSize, info.sendCount); + for (const auto& info : embInfos) { + LOG_INFO(MGMT + "table:{}, vocab size dev+host:{}+{}, send count:{}", info.name, info.devVocabSize, + info.hostVocabSize, info.sendCount); } - LOG_INFO(MGMT + "end initialize, isDDR:{}, maxStep:[{}, {}], rank:{}", rankInfo.isDDR, - rankInfo.ctrlSteps.at(TRAIN_CHANNEL_ID), rankInfo.ctrlSteps.at(EVAL_CHANNEL_ID), rankInfo.rankId); + LOG_INFO(MGMT + "end initialize, rankId:{}, isDDR:{}, " + "step[train_interval, eval_interval, save_interval, max_train_step]:[{}, {}, {}, {}]", + rankInfo.rankId, rankInfo.isDDR, rankInfo.ctrlSteps.at(TRAIN_CHANNEL_ID), + rankInfo.ctrlSteps.at(EVAL_CHANNEL_ID), rankInfo.ctrlSteps.at(SAVE_STEP_INDEX), + rankInfo.ctrlSteps.at(MAX_TRAIN_STEP_INDEX)); #endif isInitialized = true; return true; } -// 比较hostHashMap和cacheManager的数据是否一致 -void HybridMgmt::AddCacheManagerTraceLog(CkptData& saveData) -{ - if (Logger::GetLevel() != Logger::TRACE) { - return; - } - auto& embHashMaps = saveData.embHashMaps; - auto& ddrKeyFreqMap = saveData.ddrKeyFreqMaps; - for (auto& it : embHashMaps) { - string embTableName = it.first; - auto& hostMap = EmbeddingMgmt::Instance()->GetTable(embTableName)->keyOffsetMap; - auto& devSize = it.second.devVocabSize; - auto& lfu = ddrKeyFreqMap[embTableName]; - size_t tableKeyInDdr = 0; - for (const auto& item : hostMap) { - if (item.second < devSize) { - continue; - } - ++tableKeyInDdr; - auto cuKey = item.first; - if (lfu.find(cuKey) == lfu.end()) { - LOG_ERROR("save step error, ddr 
key:{}, not exist in lfu, hostHashMap offset:", - cuKey, item.second); - } - } - LOG_INFO("save step end, table:{}, tableKeyInDdr:{}, tableKeyInLfu:{}", - embTableName, tableKeyInDdr, lfu.size()); - } -} - -/// 保存CacheManager时恢复数据(与恢复hostHashMap类似,仅恢复保存数据,不修改源数据) -/// \param saveData 保存数据 -void HybridMgmt::RestoreFreq4Save(CkptData& saveData) const -{ - // 仅在差异1步时执行恢复操作 - int checkResult = hybridMgmtBlock->CheckSaveEmbMapValid(); - if (checkResult != 1) { - return; - } - auto& ddrKeyFreqMaps = saveData.ddrKeyFreqMaps; - auto& excludeDDRKeyFreqMaps = saveData.excludeDDRKeyFreqMaps; - - for (const auto& it : saveData.embHashMaps) { - auto& embTableName = it.first; - auto& embHashMap = it.second; - vector hbm2DdrKeys; - vector ddr2HbmKeys; - LOG_INFO("restore freq info for save step, table:{}, embHashMap.oldSwap size:{}", - embTableName, embHashMap.oldSwap.size()); - LOG_INFO("before, ddr key table size:{}, exclude ddr key table size:{}", - ddrKeyFreqMaps[embTableName].size(), excludeDDRKeyFreqMaps[embTableName].size()); - for (const auto& swapKeys : embHashMap.oldSwap) { - hbm2DdrKeys.emplace_back(swapKeys.second); - ddr2HbmKeys.emplace_back(swapKeys.first); - } - int hbm2DdrKeysNotInExcludeMapCount = 0; - int ddr2HbmKeysNotInDDRMapCount = 0; - for (auto& key : hbm2DdrKeys) { - if (excludeDDRKeyFreqMaps[embTableName].find(key) == excludeDDRKeyFreqMaps[embTableName].end()) { - ++hbm2DdrKeysNotInExcludeMapCount; - } - ddrKeyFreqMaps[embTableName][key] = excludeDDRKeyFreqMaps[embTableName][key]; - excludeDDRKeyFreqMaps[embTableName].erase(key); - } - for (auto& key : ddr2HbmKeys) { - if (ddrKeyFreqMaps[embTableName].find(key) == ddrKeyFreqMaps[embTableName].end()) { - ++ddr2HbmKeysNotInDDRMapCount; - } - excludeDDRKeyFreqMaps[embTableName][key] = ddrKeyFreqMaps[embTableName][key]; - ddrKeyFreqMaps[embTableName].erase(key); - } - LOG_INFO("hbm2DdrKeysNotInExcludeMapCount:{}, ddr2HbmKeysNotInDDRMapCount:{}", - hbm2DdrKeysNotInExcludeMapCount, ddr2HbmKeysNotInDDRMapCount); - LOG_INFO("after, ddr key table size:{}, exclude ddr key table size:{}", - ddrKeyFreqMaps[embTableName].size(), excludeDDRKeyFreqMaps[embTableName].size()); - } -} - /// 保存模型 /// \param savePath 保存路径 /// \return -bool HybridMgmt::Save(const string savePath) +void HybridMgmt::Save(const string& savePath) { #ifndef GTEST if (!isInitialized) { @@ -242,22 +166,17 @@ bool HybridMgmt::Save(const string savePath) Checkpoint saveCkpt; saveData.keyCountMap = KEY_PROCESS_INSTANCE->GetKeyCountMap(); - EmbeddingMgmt::Instance()->LockSave(); // acquire lock here to prevent HybridMgmt modify keyOffsetMap EmbeddingMgmt::Instance()->Save(savePath); - offsetMapToSend = EmbeddingMgmt::Instance()->GetDeviceOffsets(); + if (!mgmtRankInfo.isDDR) { + // hbm模式只保存必要的offset对应的内容 + offsetMapToSend = EmbeddingMgmt::Instance()->GetDeviceOffsets(); + } - if (isSSDEnabled) { - LOG_DEBUG(MGMT + "Start host side save: ssd mode hashmap"); - for (auto& it : cacheManager->ddrKeyFreqMap) { - saveData.ddrKeyFreqMaps[it.first] = it.second.GetFreqTable(); - } - saveData.excludeDDRKeyFreqMaps = cacheManager->excludeDDRKeyCountMap; - RestoreFreq4Save(saveData); - AddCacheManagerTraceLog(saveData); + if (isL3StorageEnabled) { + LOG_DEBUG(MGMT + "start save L3Storage data"); auto step = GetStepFromPath(savePath); - cacheManager->SaveSSDEngine(step); + cacheManager->Save(step); } - EmbeddingMgmt::Instance()->UnLockSave(); // 保存特征准入淘汰相关的数据 FeatureAdmitAndEvict& featAdmitNEvict = KEY_PROCESS_INSTANCE->GetFeatAdmitAndEvict(); @@ -272,14 +191,15 @@ bool 
HybridMgmt::Save(const string savePath) saveCkpt.SaveModel(savePath, saveData, mgmtRankInfo, mgmtEmbInfo); // 数据处理线程释放锁 KEY_PROCESS_INSTANCE->LoadSaveUnlock(); + hybridMgmtBlock->FinishSave(); + cvCheckSave.notify_all(); #endif - return true; } /// 加载模型 /// \param loadPath /// \return -bool HybridMgmt::Load(const string& loadPath) +bool HybridMgmt::Load(const string& loadPath, vector warmStartTables) { #ifndef GTEST if (!isInitialized) { @@ -295,21 +215,26 @@ bool HybridMgmt::Load(const string& loadPath) Checkpoint loadCkpt; vector loadFeatures; SetFeatureTypeForLoad(loadFeatures); + BackUpTrainStatus(); + + if (warmStartTables.size() == 0) { + EmbeddingMgmt::Instance()->Load(loadPath, trainKeysSet); + } else { + for (auto& tableName : warmStartTables) { + EmbeddingMgmt::Instance()->Load(tableName, loadPath, trainKeysSet); + } + } - EmbeddingMgmt::Instance()->Load(loadPath); - loadOffsetToSend = EmbeddingMgmt::Instance()->GetLoadOffsets(); + if (!mgmtRankInfo.isDDR) { + // hbm模式只保存必要的offset对应的内容 + loadOffsetToSend = EmbeddingMgmt::Instance()->GetLoadOffsets(); + } // 执行加载操作 loadCkpt.LoadModel(loadPath, loadData, mgmtRankInfo, mgmtEmbInfo, loadFeatures); KEY_PROCESS_INSTANCE->LoadKeyCountMap(loadData.keyCountMap); - if (mgmtRankInfo.isDDR) { - // DDR模式 将加载的hash map进行赋值 - LOG_DEBUG(MGMT + "Start host side load: ddr mode hashmap"); - auto GetEmbHashMaps = EmbeddingMgmt::Instance()->GetEmbHashMaps(); - LOG_DEBUG(MGMT + "over over Start host side load: ddr mode hashmap"); - hostHashMaps->LoadHashMap(GetEmbHashMaps); - } else { + if (!mgmtRankInfo.isDDR) { // HBM模式 将加载的最大偏移(真正使用了多少vocab容量)、特征到偏移的映射,进行赋值 LOG_DEBUG(MGMT + "Start host side load: no ddr mode hashmap"); auto keyOffsetMap = EmbeddingMgmt::Instance()->GetKeyOffsetMap(); @@ -326,15 +251,14 @@ bool HybridMgmt::Load(const string& loadPath) featAdmitNEvict.LoadHistoryRecords(loadData.histRec); } - if (isSSDEnabled) { - LOG_DEBUG(MGMT + "Start host side load: ssd key freq map"); + int& theTrainBatchId = hybridMgmtBlock->hybridBatchId[TRAIN_CHANNEL_ID]; + if (isL3StorageEnabled) { + LOG_DEBUG(MGMT + "Start host side load: L3Storage key freq map"); auto step = GetStepFromPath(loadPath); - cacheManager->Load(loadData.ddrKeyFreqMaps, loadData.excludeDDRKeyFreqMaps, - step, mgmtRankInfo.rankSize, mgmtRankInfo.rankId); - for (auto info: mgmtEmbInfo) { - auto tb = EmbeddingMgmt::Instance()->GetTable(info.name); - auto tbCast = reinterpret_pointer_cast(tb); - tbCast->RefreshFreqInfoAfterLoad(); + // When in load and train mode or predict mode, SSD needs to actually execute loading + // When in the train and eval modes, loading before eval should be directly skipped + if (theTrainBatchId == 0) { + cacheManager->Load(mgmtEmbInfo, step, trainKeysSet); } } @@ -361,10 +285,6 @@ void HybridMgmt::SetFeatureTypeForLoad(vector& loadFeatures) if (featAdmitNEvict.GetFunctionSwitch()) { loadFeatures.push_back(CkptFeatureType::FEAT_ADMIT_N_EVICT); } - - if (isSSDEnabled) { - loadFeatures.push_back(CkptFeatureType::DDR_KEY_FREQ_MAP); - } } /// 获取key对应的offset,python侧调用 @@ -437,76 +357,6 @@ void HybridMgmt::ReceiveHostMap(AllKeyOffsetMapT receiveKeyOffsetMap) #endif } -/// 对加载的数据和训练配置进行一致性校验 -/// \param loadHostEmbs -/// \param setupHostEmbs -/// \param embTableCount -/// \return -bool HybridMgmt::IsLoadDataMatches(const EmbMemT& loadHostEmbs, - const EmbInfo& setupHostEmbs, - size_t& embTableCount) const -{ - bool loadDataMatches = { true }; - const auto& loadEmbTable { loadHostEmbs.find(setupHostEmbs.name) }; - if (loadEmbTable != loadHostEmbs.end()) { - 
embTableCount++; - - const auto& loadEmbInfo { loadEmbTable->second.hostEmbInfo }; - if (setupHostEmbs.sendCount != loadEmbInfo.sendCount) { - LOG_ERROR(MGMT + "Load data sendCount {} for table {} does not match setup sendCount {}", - setupHostEmbs.sendCount, setupHostEmbs.name, loadEmbInfo.sendCount); - loadDataMatches = false; - } - if (setupHostEmbs.extEmbeddingSize != loadEmbInfo.extEmbeddingSize) { - LOG_ERROR(MGMT + "Load data extEmbeddingSize {} for table {} does not match setup extEmbeddingSize {}", - setupHostEmbs.extEmbeddingSize, setupHostEmbs.name, loadEmbInfo.extEmbeddingSize); - loadDataMatches = false; - } - if (setupHostEmbs.devVocabSize != loadEmbInfo.devVocabSize) { - LOG_ERROR(MGMT + "Load data devVocabSize {} for table {} does not match setup devVocabSize {}", - setupHostEmbs.devVocabSize, setupHostEmbs.name, loadEmbInfo.devVocabSize); - loadDataMatches = false; - } - if (setupHostEmbs.hostVocabSize != loadEmbInfo.hostVocabSize) { - LOG_ERROR(MGMT + "Load data hostVocabSize {} for table {} does not match setup hostVocabSize {}", - setupHostEmbs.hostVocabSize, setupHostEmbs.name, loadEmbInfo.hostVocabSize); - loadDataMatches = false; - } - if (!loadDataMatches) { - return false; - } - } else { - LOG_ERROR(MGMT + "Load data does not contain table with table name: {}", setupHostEmbs.name); - return false; - } - return true; -} - -/// 对DDR模式保存的模型和训练配置进行一致性校验 -/// \param loadData -/// \return 是否一致 -bool HybridMgmt::LoadMatchesDDRSetup(const CkptData& loadData) -{ - size_t embTableCount { 0 }; - auto loadHostEmbs { loadData.hostEmbs }; - if (loadHostEmbs == nullptr) { - LOG_ERROR(MGMT + "Host Embedding of load checkpoint data is nullptr!"); - return false; - } - for (EmbInfo setupHostEmbs : mgmtEmbInfo) { - if (!IsLoadDataMatches(*loadHostEmbs, setupHostEmbs, embTableCount)) { - return false; - } - } - - if (embTableCount < loadHostEmbs->size()) { - LOG_ERROR(MGMT + "Load data has {} tables more than setup table num {}", - loadHostEmbs->size(), embTableCount); - return false; - } - return true; -} - /// 根据HBM/DDR模式,启动数据处理线程 void HybridMgmt::Start() { @@ -523,17 +373,17 @@ void HybridMgmt::Start() void HybridMgmt::StartThreadForHBM() { #ifndef GTEST - auto parseKeysTaskForHBMTrain = [this]() { - TrainTask(TaskType::HBM); - LOG_INFO("parseKeysTaskForHBMTrain done"); - }; - procThreads.emplace_back(std::make_unique(parseKeysTaskForHBMTrain)); - - auto parseKeysTaskForHBMEval = [this]() { - EvalTask(TaskType::HBM); - LOG_INFO("parseKeysTaskForHBMEval done"); - }; - procThreads.emplace_back(std::make_unique(parseKeysTaskForHBMEval)); + auto parseKeysTaskForHBMTrain = [this]() { + TrainTask(TaskType::HBM); + LOG_INFO("parseKeysTaskForHBMTrain done"); + }; + procThreads.emplace_back(std::make_unique(parseKeysTaskForHBMTrain)); + + auto parseKeysTaskForHBMEval = [this]() { + EvalTask(TaskType::HBM); + LOG_INFO("parseKeysTaskForHBMEval done"); + }; + procThreads.emplace_back(std::make_unique(parseKeysTaskForHBMEval)); #endif } @@ -551,6 +401,12 @@ void HybridMgmt::StartThreadForDDR() LOG_INFO("parseKeysTaskForEval done"); }; procThreads.emplace_back(std::make_unique(parseKeysTaskForEval)); + + auto embeddingProcessTask = [this]() { + EmbeddingTask(); + LOG_INFO("embeddingProcessTask done"); + }; + procThreads.emplace_back(std::make_unique(embeddingProcessTask)); #endif } @@ -567,6 +423,17 @@ void HybridMgmt::Destroy() // 先发送停止信号mgmt,先停止新lookup查询, 解除queue的限制防止卡住 isRunning = false; + mutexDestroy = true; + for (const auto& embInfo : mgmtEmbInfo) { + for (int index = 0; index < 
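// Aside: Destroy() wakes every per-table, per-thread condition variable (the loops just
// below) because the embedding threads block in guarded waits keyed by table name and
// thread index. Sketch of the matching wait site, with an assumed mutex name; the point
// is that mutexDestroy must be part of the predicate so Destroy() can unblock it:
//
//     std::unique_lock<std::mutex> lk(embStateMutex);  // embStateMutex: assumed name
//     cvLastUpdateFinishMap[name][idx].wait(lk, [&] {
//         return lastUpdateFinishStepMap[name] >= wantedStep || mutexDestroy;
//     });
//     if (mutexDestroy) {
//         return;  // woken by Destroy(), abandon the in-flight step
//     }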
EMBEDDING_THREAD_NUM; index++) { + cvLastUpdateFinishMap[embInfo.name][index].notify_all(); + cvLastLookUpFinishMap[embInfo.name][index].notify_all(); + cvLastSendFinishMap[embInfo.name][index].notify_all(); + cvLastRecvFinishMap[embInfo.name][index].notify_all(); + } + } + cvCheckSave.notify_all(); // 防止save异常退出场景阻塞在EvalTask + { // 获取锁 避免KeyProcess中手动发送结束信息时通道关闭 std::unique_lock lockGuard(KEY_PROCESS_INSTANCE->destroyMutex); @@ -584,22 +451,22 @@ void HybridMgmt::Destroy() if (cacheManager != nullptr) { cacheManager = nullptr; } - if (hostEmbs != nullptr) { - hostEmbs->Join(TRAIN_CHANNEL_ID); - hostEmbs->Join(EVAL_CHANNEL_ID); - hostEmbs = nullptr; - } + JoinEmbeddingCacheThread(); procThreads.clear(); // 停止预处理 KEY_PROCESS_INSTANCE->Destroy(); + // stop embCache, even if the host emb is still allocating + if (embCache != nullptr) { + embCache->Destroy(); + } LOG_DEBUG(MGMT + "Destroy hybrid_mgmt module end."); -}; +} -#ifndef GTEST /// 启动hybrid处理任务 /// \param type void HybridMgmt::TrainTask(TaskType type) { +#ifndef GTEST int channelId = TRAIN_CHANNEL_ID; int& theTrainBatchId = hybridMgmtBlock->hybridBatchId[channelId]; do { @@ -612,19 +479,9 @@ void HybridMgmt::TrainTask(TaskType type) } LOG_INFO(HYBRID_BLOCKING + "hybrid start task channel {} batch {}", channelId, theTrainBatchId); - switch (type) { - case TaskType::HBM: - ParseKeysHBM(TRAIN_CHANNEL_ID, theTrainBatchId); - LOG_INFO(MGMT + "ParseKeysHBMBatchId = {}", theTrainBatchId); - break; - case TaskType::DDR: - ParseKeys(TRAIN_CHANNEL_ID, theTrainBatchId); - LOG_INFO(MGMT + "parseKeysBatchId = {}", theTrainBatchId); - break; - default: - throw std::invalid_argument("Invalid TaskType Type."); - } + ParseKeys(TRAIN_CHANNEL_ID, theTrainBatchId, type); } while (true); +#endif } /// 推理数据处理:数据处理状态正常,处理的batch数小于用户预设值或者设为-1时,会循环处理; @@ -632,11 +489,27 @@ void HybridMgmt::TrainTask(TaskType type) /// \return void HybridMgmt::EvalTask(TaskType type) { +#ifndef GTEST int channelId = EVAL_CHANNEL_ID; int& evalBatchId = hybridMgmtBlock->hybridBatchId[channelId]; do { hybridMgmtBlock->CheckAndSetBlock(channelId); if (hybridMgmtBlock->GetBlockStatus(channelId)) { + LOG_DEBUG("eval channel block at batchId:{}, needWaitSave:{}", evalBatchId, + hybridMgmtBlock->IsNeedWaitSave()); + std::unique_lock checkSaveLocker(saveMutex); + cvCheckSave.wait(checkSaveLocker, [this] { return !hybridMgmtBlock->IsNeedWaitSave() || mutexDestroy; }); + + if (hybridMgmtBlock->pythonBatchId[EVAL_CHANNEL_ID] >= hybridMgmtBlock->hybridBatchId[EVAL_CHANNEL_ID]) { + // Before waking the data process for training, Recover the backed-up training state + RecoverTrainStatus(); + hybridMgmtBlock->Wake(TRAIN_CHANNEL_ID); + } else { + std::this_thread::sleep_for(SLEEP_MS); + continue; + } + + LOG_DEBUG("wake TrainTask"); hybridMgmtBlock->DoBlock(channelId); } if (!isRunning) { @@ -644,328 +517,244 @@ void HybridMgmt::EvalTask(TaskType type) } LOG_INFO(HYBRID_BLOCKING + "hybrid start task channel {} batch {}", channelId, evalBatchId); - switch (type) { - case TaskType::HBM: - ParseKeysHBM(EVAL_CHANNEL_ID, evalBatchId); - LOG_INFO(MGMT + "HBM evalBatchId = {}", evalBatchId); - break; - case TaskType::DDR: - ParseKeys(EVAL_CHANNEL_ID, evalBatchId); - LOG_INFO(MGMT + "DDR evalBatchId = {}", evalBatchId); - break; - default: - throw std::invalid_argument("Invalid TaskType Type."); - } + ParseKeys(EVAL_CHANNEL_ID, evalBatchId, type); } while (true); +#endif } -/// HBM模式下,发送key process线程已处理好的各类型向量到指定通道中 -/// \param channelId 通道索引(训练/推理) -/// \param batchId 已处理的batch数 -/// \return -bool 
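// Aside: the cvCheckSave wait in EvalTask above is one half of a plain condition-variable
// handshake; Save() supplies the other half once the checkpoint is written (both lines
// appear verbatim in the Save() hunk earlier in this patch):
//
//     hybridMgmtBlock->FinishSave();   // mark the pending save as finished
//     cvCheckSave.notify_all();        // release EvalTask from its wait() predicate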
HybridMgmt::ParseKeysHBM(int channelId, int& batchId) -{ - LOG_INFO(MGMT + "nBatch:{} channelId:{} batchId:{}, ParseKeys with HBM mode start.", - mgmtRankInfo.nBatch, channelId, batchId); - - // 循环处理每个表的数据 - for (const auto& embInfo: mgmtEmbInfo) { - TimeCost parseKeysTc; - // 获取各类向量,如果为空指针,退出当前函数 - auto infoVecs = KEY_PROCESS_INSTANCE->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::RESTORE); - if (infoVecs == nullptr) { - LOG_INFO(MGMT + "channelId:{} batchId:{}, ParseKeys infoVecs empty !", channelId, batchId); - return false; - } - LOG_DEBUG("channelId:{} batchId:{}, ParseKeysHBM GetInfoVec end.", channelId, batchId); - // 动态shape场景下,获取all2all向量(通信量矩阵) - TimeCost sendTensorsSyncTC; - unique_ptr> all2all = nullptr; - if (!mgmtRankInfo.useStatic) { - TimeCost getTensorsSyncTC; - all2all = KEY_PROCESS_INSTANCE->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::ALL2ALL); - LOG_DEBUG("channelId:{} batchId:{}, getTensorsSyncTC(ms):{}", - channelId, batchId, getTensorsSyncTC.ElapsedMS()); - if (all2all == nullptr) { - LOG_ERROR("Information vector is nullptr!"); - return false; - } - sendTensorsSyncTC = TimeCost(); // 重新初始化,不计算getTensors耗时 - TimeCost sendAll2AllScSyncTC; - hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embInfo.name); - LOG_DEBUG("channelId:{} batchId:{}, sendAll2AllScSyncTC(ms):{}", - channelId, batchId, sendAll2AllScSyncTC.ElapsedMS()); - } - - // 发送查询向量 - TimeCost sendLookupSyncTC; - hdTransfer->Send(TransferChannel::LOOKUP, { infoVecs->back() }, channelId, embInfo.name); - infoVecs->pop_back(); - LOG_DEBUG("channelId:{} batchId:{}, sendLookupSyncTC(ms):{}", channelId, batchId, sendLookupSyncTC.ElapsedMS()); - - // 训练时,使用全局去重聚合梯度,发送全局去重的key和对应的恢复向量 - if (GlobalEnv::applyGradientsStrategy == ApplyGradientsStrategyOptions::SUM_SAME_ID_GRADIENTS_AND_APPLY && - channelId == TRAIN_CHANNEL_ID) { - SendUniqKeysAndRestoreVecHBM(channelId, batchId, embInfo, infoVecs); - } - - // 发送恢复向量 - TimeCost sendRestoreSyncTC; - hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, channelId, embInfo.name); - LOG_DEBUG("sendRestoreSyncTC(ms):{}, sendTensorsSyncTC(ms):{}, parseKeysTc HBM mode (ms):{}", - sendRestoreSyncTC.ElapsedMS(), sendTensorsSyncTC.ElapsedMS(), parseKeysTc.ElapsedMS()); - LOG_INFO(MGMT + "channelId:{} batchId:{}, embName:{}, ParseKeys with HBM mode end.", - channelId, batchId, embInfo.name); - } - batchId++; - return true; -} - -void HybridMgmt::SendUniqKeysAndRestoreVecHBM(int channelId, int &batchId, const EmbInfo &embInfo, - const unique_ptr> &infoVecs) const +void HybridMgmt::SendUniqKeysAndRestoreVecHBM(const EmbBaseInfo& info, const unique_ptr>& infoVecs, + bool isGrad) const { TimeCost sendUniqueKeysSyncTC; - LOG_DEBUG("channelId:{} batchId:{}, global unique, table name: {}, is grad: {}", - channelId, batchId, embInfo.name, embInfo.isGrad); - if (embInfo.isGrad) { - hdTransfer->Send(TransferChannel::UNIQKEYS, {infoVecs->back()}, channelId, embInfo.name); + LOG_DEBUG("channelId:{} batchId:{}, global unique, table name: {}, is grad: {}", info.channelId, info.batchId, + info.name, isGrad); + if (isGrad) { + hdTransfer->Send(TransferChannel::UNIQKEYS, {infoVecs->back()}, info.channelId, info.name); } infoVecs->pop_back(); - LOG_DEBUG("channelId:{} batchId:{}, sendUniqueKeysSyncTC(ms):{}", - channelId, batchId, sendUniqueKeysSyncTC.ElapsedMS()); + LOG_DEBUG("channelId:{} batchId:{}, sendUniqueKeysSyncTC(ms):{}", info.channelId, info.batchId, + sendUniqueKeysSyncTC.ElapsedMS()); TimeCost sendUniqueRestoreVecSyncTC; - if 
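// Aside: UNIQKEYS and RESTORE_SECOND only go to the device for tables that actually
// receive gradients; with the SUM_SAME_ID_GRADIENTS strategy the device sums the
// gradients of duplicate keys before applying them. Note that both the old and the new
// code still call pop_back() unconditionally, so infoVecs is consumed the same way
// whether or not the table is trainable.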
(embInfo.isGrad) {
-        hdTransfer->Send(TransferChannel::RESTORE_SECOND, {infoVecs->back()}, channelId, embInfo.name);
+    if (isGrad) {
+        hdTransfer->Send(TransferChannel::RESTORE_SECOND, {infoVecs->back()}, info.channelId, info.name);
     }
     infoVecs->pop_back();
-    LOG_DEBUG("channelId:{} batchId:{}, sendUniqueRestoreVecSyncTC(ms):{}",
-        channelId, batchId, sendUniqueRestoreVecSyncTC.ElapsedMS());
+    LOG_DEBUG("channelId:{} batchId:{}, sendUniqueRestoreVecSyncTC(ms):{}", info.channelId, info.batchId,
+        sendUniqueRestoreVecSyncTC.ElapsedMS());
 }
-#endif
-
-/// Is the batch currently being processed the last one
+/// Is the batch currently being processed the last one; covers the train-to-eval switch and the save scenario
 /// \param batchId number of batches processed so far
-/// \param channelId channel index (train/eval)
 /// \return
-bool HybridMgmt::EndBatch(int batchId, int channelId) const
+bool HybridMgmt::IsTrainEndBatch(int batchId) const
 {
-    return (batchId % mgmtRankInfo.ctrlSteps[channelId] == 0 && mgmtRankInfo.ctrlSteps[channelId] != -1);
+    // case 1: the train channel has to switch over to eval
+    // case 2: a save is due; after re-sending the positions this thread blocks until the save
+    //         finishes, so that the embCache state cannot change while it is being saved
+    // batchId starts at 0, so add 1 to line it up with the step count
+    bool isNeedSwitchToEval =
+        mgmtRankInfo.ctrlSteps[TRAIN_CHANNEL_ID] != -1 && (batchId + 1) % mgmtRankInfo.ctrlSteps[TRAIN_CHANNEL_ID] == 0;
+    bool isNeedSave = mgmtRankInfo.ctrlSteps[SAVE_STEP_INDEX] != -1 && mgmtRankInfo.ctrlSteps[SAVE_STEP_INDEX] != 0 &&
+                      (batchId + 1) % mgmtRankInfo.ctrlSteps[SAVE_STEP_INDEX] == 0;
+    LOG_DEBUG("mgmtRankInfo.ctrlSteps[TRAIN_CHANNEL_ID]:{}, batchId:{}", mgmtRankInfo.ctrlSteps[TRAIN_CHANNEL_ID],
+              batchId);
+    LOG_DEBUG("isNeedSwitchToEval:{}, isNeedSave:{}", isNeedSwitchToEval, isNeedSave);
+    return isNeedSwitchToEval || isNeedSave;
+}
+
+bool HybridMgmt::IsEvalEndBatch(int batchId) const
+{
+    // batchId starts at 0, so add 1 to line it up with the step count; eval ends after the current step
+    return (batchId + 1) == hybridMgmtBlock->stepsInterval[EVAL_CHANNEL_ID];
 }
 
 /// In DDR mode, send the vectors prepared by the key-process threads to the specified channel
 /// \param channelId channel index (train/eval)
 /// \param batchId number of batches processed so far
 /// \return
-bool HybridMgmt::ParseKeys(int channelId, int& batchId)
+bool HybridMgmt::ParseKeys(int channelId, int& batchId, TaskType type)
 {
 #ifndef GTEST
-    LOG_INFO(MGMT + "channelId:{} batchId:{}, DDR mode, ParseKeys start.", channelId, batchId);
+    LOG_INFO(MGMT + "channelId:{} batchId:{}, ParseKeys start.", channelId, batchId);
     TimeCost parseKeyTC;
-    int start = batchId;
-    bool remainBatch = true; // whether data was fetched from the channel
+    bool remainBatch = true; // whether data was fetched from the channel
+    vector<thread> parseKeyThreadPool;
     for (const auto& embInfo : mgmtEmbInfo) {
-        ProcessEmbInfo(embInfo.name, batchId, channelId, remainBatch);
-        // the channel has run out of data
-        if (!remainBatch) {
-            LOG_DEBUG("last batch ending");
-            return false;
+        EmbBaseInfo info = {.batchId = batchId, .channelId = channelId, .name = embInfo.name};
+        switch (type) {
+            case TaskType::HBM:
+                parseKeyThreadPool.emplace_back(
+                    [this, info, &remainBatch, embInfo]() { ProcessEmbInfoHBM(info, remainBatch, embInfo.isGrad); });
+                break;
+            case TaskType::DDR:
+                if (!isL3StorageEnabled) {
+                    parseKeyThreadPool.emplace_back(
+                        [this, info, &remainBatch, embInfo]() { ProcessEmbInfoDDR(info, remainBatch); });
+                } else {
+                    parseKeyThreadPool.emplace_back(
+                        [this, info, &remainBatch, embInfo]() { ProcessEmbInfoL3Storage(info, remainBatch); });
+                }
+                break;
+            default:
+                throw std::invalid_argument("Invalid TaskType.");
        }
    }
-    batchId++;
+    for (auto& t : parseKeyThreadPool) {
+        t.join();
+    }
+    // the channel has run out of data
+    if (!remainBatch) {
+        LOG_DEBUG("last batch ending");
+        return false;
+    }
    if (!isRunning) {
        return false;
    }
-    EmbHDTransWrap(channelId, batchId - 1, start);
-    LOG_DEBUG(MGMT + "channelId:{} batchId:{}, ParseKeys end,
parseKeyTC(ms):{}", - channelId, batchId, parseKeyTC.ElapsedMS()); + LOG_DEBUG(MGMT + "channelId:{} batchId:{}, ParseKeys end, parseKeyTC(ms):{}", channelId, batchId, + parseKeyTC.ElapsedMS()); + batchId++; #endif return true; } -void HybridMgmt::HandlePrepareDDRDataRet(TransferRet prepareSSDRet) const +void HybridMgmt::ProcessEmbInfoHBM(const EmbBaseInfo& info, bool& remainBatchOut, bool isGrad) { - LOG_ERROR("Transfer embedding with DDR and SSD error."); - if (prepareSSDRet == TransferRet::SSD_SPACE_NOT_ENOUGH) { - LOG_ERROR("PrepareDDRData: SSD available space is not enough."); - throw runtime_error("ssdVocabSize too small"); + TimeCost parseKeysTc; + LOG_DEBUG("ProcessEmbInfoHBM table:{}, batchId:{}, channel:{}", info.name, info.batchId, info.channelId); + + // 获取各类向量,如果为空指针,退出当前函数 + bool isEos = false; + auto infoVecs = KEY_PROCESS_INSTANCE->GetInfoVec(info, ProcessedInfo::RESTORE, isEos); + if (isEos) { + HandleEosCaseHBM(info.name, info.batchId, info.channelId, remainBatchOut); + return; } - if (prepareSSDRet == TransferRet::DDR_SPACE_NOT_ENOUGH) { - LOG_ERROR("PrepareDDRData: DDR available space is not enough."); - throw runtime_error("ddrVocabSize too small"); + if (infoVecs == nullptr) { + LOG_INFO(MGMT + "table:{}, channelId:{} batchId:{}, ParseKeys infoVecs empty !", info.name, info.channelId, + info.batchId); + remainBatchOut = false; + return; } - throw runtime_error("Transfer embedding with DDR and SSD error."); -} + LOG_DEBUG("table:{}, channelId:{} batchId:{}, ParseKeysHBM GetInfoVec end", info.name, info.channelId, + info.batchId); -#ifndef GTEST + // 动态shape场景下,获取all2all向量(通信量矩阵) + SendAll2AllVec(info, remainBatchOut); + if (!remainBatchOut) { + return; + } + + // 发送查询向量 + TimeCost sendLookupSyncTC; + hdTransfer->Send(TransferChannel::LOOKUP, {infoVecs->back()}, info.channelId, info.name); + infoVecs->pop_back(); + LOG_DEBUG("table:{}, channelId:{} batchId:{}, sendLookupSyncTC(ms):{}", info.name, info.channelId, info.batchId, + sendLookupSyncTC.ElapsedMS()); + + // 训练时,使用全局去重聚合梯度,发送全局去重的key和对应的恢复向量 + if (mgmtRankInfo.useSumSameIdGradients && info.channelId == TRAIN_CHANNEL_ID) { + SendUniqKeysAndRestoreVecHBM(info, infoVecs, isGrad); + } + + // 发送恢复向量 + TimeCost sendRestoreSyncTC; + hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, info.channelId, info.name); + LOG_DEBUG("table:{}, sendRestoreSyncTC(ms):{}, parseKeysTc HBM mode (ms):{}", info.name, + sendRestoreSyncTC.ElapsedMS(), parseKeysTc.ElapsedMS()); + + LOG_INFO(MGMT + "table:{}, channelId:{} batchId:{}, embName:{}, ParseKeys with HBM mode end.", info.name, + info.channelId, info.batchId, info.name); + + if (info.channelId == TRAIN_CHANNEL_ID) { + alreadyTrainOnce = true; + } +} /// 构造训练所需的各种向量数据 /// \param embName 表名 /// \param batchId 已处理的batch数 /// \param channelId 通道索引(训练/推理) /// \param remainBatchOut 是否从通道获取了数据 -/// \return HBM是否还有剩余空间 -bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, int channelId, bool& remainBatchOut) +void HybridMgmt::ProcessEmbInfoDDR(const EmbBaseInfo& info, bool& remainBatchOut) { +#ifndef GTEST TimeCost getAndSendTensorsTC; - TimeCost getTensorsTC; + LOG_DEBUG("ProcessEmbInfoDDR start, table:{}, channel:{}, batchId:{}", info.name, info.channelId, info.batchId); - if (hostHashMaps->embHashMaps.find(embName) == hostHashMaps->embHashMaps.end()) { - LOG_ERROR("Failed to get embedding hash map with given name: {}", embName); - return false; + if (info.channelId == TRAIN_CHANNEL_ID && info.batchId == hybridMgmtBlock->maxTrainStep) { + HandleReachMaxStepCase(info, 
remainBatchOut); + return; } - auto& embHashMap = hostHashMaps->embHashMaps.at(embName); - // 计数初始化 - std::shared_ptr table = EmbeddingMgmt::Instance()->GetTable(embName); - table->SetStartCount(); + // 只有在每次GetUniqueKeys的时候才知道上游是否已经EOS + // 注意GetUniqueKeys与EOS关联,需要在ProcessEmbInfoDDR最先调用,如需调整位置,请参考并适配其他函数 + // 获取GlobalUnique向量 + auto uniqueKeys = GetUniqueKeys(info, remainBatchOut); + if (uniqueKeys.empty()) { + return; + } - // 获取查询向量 - auto lookupKeys = KEY_PROCESS_INSTANCE->GetLookupKeys(batchId, embName, channelId); - if (lookupKeys.empty()) { - remainBatchOut = false; - LOG_WARN("channelId:{} batchId:{}, embName:{}, GetLookupKeys result is empty.", channelId, batchId, embName); - return false; + // 获取GlobalUnique对应的restoreVectorSec + auto restoreVecSec = GetRestoreVecSec(info, remainBatchOut); + if (restoreVecSec.empty()) { + return; } - LOG_DEBUG("channelId:{} batchId:{}, embName:{}, GetLookupKeys end.", channelId, batchId, embName); - // 获取各类向量,如果为空指针,退出当前函数 - unique_ptr> infoVecs = KEY_PROCESS_INSTANCE->GetInfoVec(batchId, embName, channelId, - ProcessedInfo::RESTORE); - if (infoVecs == nullptr) { - LOG_ERROR("Information vector is nullptr!"); - return false; + + SendAll2AllVec(info, remainBatchOut); + if (!remainBatchOut) { + return; } - LOG_DEBUG("channelId:{} batchId:{}, GetInfoVec end, getTensorsTC(ms):{}", - channelId, batchId, getTensorsTC.ElapsedMS()); - TimeCost sendRestoreSyncTC; - hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, channelId, embName); - LOG_DEBUG("channelId:{} batchId:{}, send restore end, sendRestoreSyncTC(ms):{}", - channelId, batchId, sendRestoreSyncTC.ElapsedMS()); + SendRestoreVec(info, remainBatchOut); + if (!remainBatchOut) { + return; + } - // 调用SSD cache缓存处理流程,获取锁避免保存时修改keyOffsetMap - table->mutSave_.lock(); - LOG_DEBUG("acquire save lock, table:{}", table->name); - PrepareDDRData(table, lookupKeys, channelId, batchId); + std::pair, vector> swapInKoPair; + std::pair, vector> swapOutKoPair; + GetSwapPairsAndKey2Offset(info, uniqueKeys, swapInKoPair, swapOutKoPair); - // 计算查询向量;记录需要被换出的HBM偏移 - vector tmpData; - vector offsetsOut; - DDRParam ddrParam(tmpData, offsetsOut); - TimeCost hostHashMapProcessTC; + SendLookupOffsets(info, uniqueKeys, restoreVecSec); - hostHashMaps->Process(embName, lookupKeys, ddrParam, channelId); - table->mutSave_.unlock(); - LOG_DEBUG("release save lock, table:{}", table->name); + SendGlobalUniqueVec(info, uniqueKeys, restoreVecSec); - LOG_DEBUG("channelId:{} batchId:{}, hostHashMapProcessTC(ms):{}", - channelId, batchId, hostHashMapProcessTC.ElapsedMS()); + TimeCost swapProcessTC; + auto& swapInPos = swapInKoPair.second; + auto& swapOutPos = swapOutKoPair.second; + auto lastSwapInPos = lastSwapInPosMap[info.name]; + lastSwapInPosMap[info.name] = swapInPos; // 暂存待下一步发送 - if (GlobalEnv::applyGradientsStrategy == ApplyGradientsStrategyOptions::SUM_SAME_ID_GRADIENTS_AND_APPLY && - channelId == TRAIN_CHANNEL_ID && remainBatchOut) { - SendUniqKeysAndRestoreVecDDR(embName, batchId, channelId, ddrParam); + auto isNeedReturn = HandleSpecialProcessStatusDDR(info, getAndSendTensorsTC, swapInKoPair, swapOutKoPair); + if (isNeedReturn) { + return; } - TimeCost sendTensorsTC; - hdTransfer->Send(TransferChannel::LOOKUP, { ddrParam.tmpDataOut.front() }, channelId, embName); - ddrParam.tmpDataOut.erase(ddrParam.tmpDataOut.cbegin()); - hdTransfer->Send(TransferChannel::SWAP, ddrParam.tmpDataOut, channelId, embName); - if (!mgmtRankInfo.useStatic) { - unique_ptr> all2all = KEY_PROCESS_INSTANCE->GetInfoVec(batchId, embName, - channelId, 
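// Aside: swap-in positions are deliberately shipped one batch late. Batch N computes
// swapInPos and parks it in lastSwapInPosMap; batch N+1 sends it together with its own
// swap-out positions (SendTensorForSwap is skipped for batchId 0, which has no
// predecessor). The bookkeeping, as it appears in the new ProcessEmbInfoDDR above:
//
//     auto lastSwapInPos = lastSwapInPosMap[info.name];  // produced by batch N-1
//     lastSwapInPosMap[info.name] = swapInPos;           // parked for batch N+1
//     if (info.batchId != 0) {
//         SendTensorForSwap(info, lastSwapInPos, swapOutPos);
//     }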
ProcessedInfo::ALL2ALL); - if (all2all == nullptr) { - LOG_ERROR("Information vector is nullptr!"); - return false; - } - hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embName); - } - LOG_DEBUG("channelId:{} batchId:{}, ProcessEmbInfo end, sendTensorsTC(ms):{}, getAndSendTensorsTC(ms):{}", - channelId, batchId, sendTensorsTC.ElapsedMS(), getAndSendTensorsTC.ElapsedMS()); + EnqueueSwapInfo(info, swapInKoPair, swapOutKoPair); - if (!isSSDEnabled && embHashMap.HasFree(lookupKeys.size())) { // check free > next one batch - LOG_WARN(MGMT + "channelId:{} batchId:{}, embName:{}, freeSize not enough:{}", - channelId, batchId, embName, lookupKeys.size()); - return false; + // 下发swaptensor + if (info.batchId != 0) { + SendTensorForSwap(info, lastSwapInPos, swapOutPos); } - return true; -} - -void HybridMgmt::SendUniqKeysAndRestoreVecDDR(const string &embName, int &batchId, int &channelId, DDRParam &ddrParam) -{ - LOG_DEBUG("channelId:{} batchId:{}, embName:{}, SendUniqKeysAndRestoreVecDDR start.", channelId, batchId, embName); - vector uniqueKeys; - vector restoreVecSec; - KEY_PROCESS_INSTANCE->GlobalUnique(ddrParam.offsetsOut, uniqueKeys, restoreVecSec); - TimeCost sendUniqueKeysSyncTC; - hdTransfer->Send(TransferChannel::UNIQKEYS, {mgmtRankInfo.useDynamicExpansion ? Vec2TensorI64(uniqueKeys) : - Vec2TensorI32(uniqueKeys) }, channelId, embName); - LOG_DEBUG("channelId:{} batchId:{}, sendUniqueKeysSyncTC(ms):{}", - channelId, batchId, sendUniqueKeysSyncTC.ElapsedMS()); - - TimeCost sendRestoreVecSecSyncTC; - hdTransfer->Send(TransferChannel::RESTORE_SECOND, {Vec2TensorI32(restoreVecSec) }, channelId, embName); - LOG_DEBUG("channelId:{} batchId:{}, sendRestoreVecSecSyncTC(ms):{}", - channelId, batchId, sendRestoreVecSecSyncTC.ElapsedMS()); -} + HandleEndBatchCase(info, swapInPos); -/// 发送H2D和接收D2H向量 -/// \param channelId 通道索引(训练/推理) -/// \param batchId 已处理的batch数 -/// \param start -void HybridMgmt::EmbHDTransWrap(int channelId, const int& batchId, int start) -{ - LOG_INFO(MGMT + "start:{} channelId:{} batchId:{}, EmbHDTransWrap start.", start, channelId, batchId); - TimeCost embHDTransWrapTC; - TimeCost hostEmbsTC; - hostEmbs->Join(channelId); - LOG_DEBUG("channelId:{} batchId:{}, hostEmbs Join end, hostEmbsTC(ms):{}", - channelId, batchId, hostEmbsTC.ElapsedMS()); - if (!isRunning) { - return; + if (info.channelId == TRAIN_CHANNEL_ID) { + alreadyTrainOnce = true; } - EmbHDTrans(channelId, batchId); - LOG_DEBUG("channelId:{} batchId:{}, EmbHDTransWrap end, embHDTransWrapTC(ms):{}", - channelId, batchId, embHDTransWrapTC.ElapsedMS()); -} -/// 发送H2D和接收D2H向量,并更新host emb -/// \param channelId 通道索引(训练/推理) -/// \param batchId 已处理的batch数 -void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) -{ - EASY_FUNCTION(profiler::colors::Blue) - EASY_VALUE("mgmtProcess", batchId) - LOG_DEBUG(MGMT + "channelId:{} batchId:{}, EmbHDTrans start.", channelId, batchId); - TimeCost h2dTC; - // 发送host需要换出的emb - for (const auto& embInfo: mgmtEmbInfo) { - const auto& missingKeys = EmbeddingMgmt::Instance()->GetMissingKeys(embInfo.name); - vector h2dEmb; - hostEmbs->GetH2DEmb(missingKeys, embInfo.name, h2dEmb); // order! 
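// Aside: the "order!" marks in this deleted helper flag an implicit contract: both loops
// had to walk mgmtEmbInfo in the same table order, because each H2D send is paired by
// the device with a D2H reply for the same table, and reordering either loop would match
// a reply against the wrong missingKeys. The per-table EmbeddingTask threads and queues
// added later in this patch replace that implicit ordering with explicit per-table
// channels.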
- hdTransfer->Send(TransferChannel::H2D, h2dEmb, channelId, embInfo.name, batchId); - } - LOG_DEBUG("channelId:{} batchId:{}, EmbHDTrans h2d end, h2dTC(ms):{}", channelId, batchId, h2dTC.ElapsedMS()); - - TimeCost d2hTC; - // 接收device换出的emb,并更新到host上 - for (const auto& embInfo: mgmtEmbInfo) { - const auto& missingKeys = EmbeddingMgmt::Instance()->GetMissingKeys(embInfo.name); - hostEmbs->UpdateEmbV2(missingKeys, channelId, embInfo.name); // order! - EmbeddingMgmt::Instance()->ClearMissingKeys(embInfo.name); - } - LOG_DEBUG("channelId:{} batchId:{}, EmbHDTrans d2h end, d2hTC(ms):{}", channelId, batchId, d2hTC.ElapsedMS()); -} + LOG_DEBUG("ProcessEmbInfoDDR end, table:{}, channel:{}, batchId:{} swapProcessTC(ms):{} getAndSendTensorsTC(ms):{}", + info.name, info.channelId, info.batchId, swapProcessTC.ElapsedMS(), getAndSendTensorsTC.ElapsedMS()); #endif +} /// hook通过时间或者step数触发淘汰 /// \return bool HybridMgmt::Evict() { #ifndef GTEST + std::lock_guard lk(evictMut); if (!isInitialized) { throw runtime_error("HybridMgmt not initialized. Call Initialize first."); } @@ -996,13 +785,20 @@ bool HybridMgmt::Evict() } } else { if (GlobalEnv::useCombineFaae) { - for (auto& map : hostHashMaps->embHashMaps) { - EmbeddingMgmt::Instance()->EvictKeys(map.first, evictKeyMap[COMBINE_HISTORY_NAME]); + vector allTableNames; + int retCode = embCache->GetEmbTableNames(allTableNames); + if (retCode != H_OK) { + LOG_ERROR("GetEmbTableNames failed!"); + return false; + } + for (const string& embName : allTableNames) { + EvictKeys(embName, evictKeyMap[COMBINE_HISTORY_NAME]); + EvictL3StorageKeys(embName, evictKeyMap[COMBINE_HISTORY_NAME]); } } else { for (const auto& evict : as_const(evictKeyMap)) { EvictKeys(evict.first, evict.second); - EvictSSDKeys(evict.first, evict.second); + EvictL3StorageKeys(evict.first, evict.second); } } } @@ -1014,90 +810,24 @@ bool HybridMgmt::Evict() /// DDR模式下的淘汰:删除映射表、初始化host表、发送dev淘汰位置 /// \param embName /// \param keys -void HybridMgmt::EvictKeys(const string& embName, const vector& keys) -{ - std::shared_ptr table = EmbeddingMgmt::Instance()->GetTable(embName); - - table->EvictKeys(keys); - - const vector& evictOffsetDev = table->GetEvictedKeys(); - const vector& evictOffsetHost = table->GetHostEvictedKeys(); - - vector evictOffsetHostx(evictOffsetHost); - - size_t devVocabSize = table->GetDevVocabSize(); - for (int64_t& key: evictOffsetHostx) { - key -= static_cast(devVocabSize); - }; - - /* 淘汰Host侧 */ - if (!evictOffsetHost.empty()) { - hostEmbs->EvictInitEmb(embName, evictOffsetHost); - } - - vector tmpDataOut; - Tensor tmpData = Vec2TensorI32(evictOffsetDev); - tmpDataOut.emplace_back(tmpData); - tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 })); - - auto evictLen = tmpDataOut.back().flat(); - auto evictSize = static_cast(evictOffsetDev.size()); - evictLen(0) = evictSize; - - hdTransfer->Send(TransferChannel::EVICT, tmpDataOut, TRAIN_CHANNEL_ID, embName); -} - -inline void HybridMgmt::PrepareDDRData(std::shared_ptr table, - const vector& keys, int channelId, int batchId) const +void HybridMgmt::EvictKeys(const string& embName, const vector& keys) { - if (!isSSDEnabled) { + if (keys.empty()) { return; } - LOG_DEBUG("channelId:{} batchId:{}, embTableName:{}, PrepareDDRData start.", channelId, batchId, table->name); - TimeCost prepareDDRDataTc; - TableInfo ti = table->GetTableInfo(); - TransferRet ret = cacheManager->TransferDDREmbWithSSD(ti, keys, channelId); - if (ret != TransferRet::TRANSFER_OK) { - HandlePrepareDDRDataRet(ret); - } - LOG_DEBUG("channelId:{} 
batchId:{}, embTableName:{}, PrepareDDRData end, prepareDDRDataTc(ms):{}", - channelId, batchId, table->name, prepareDDRDataTc.ElapsedMS()); -} - -void HybridMgmt::EvictSSDKeys(const string& embName, const vector& keys) const -{ - if (!isSSDEnabled) { + int retCode = embCache->RemoveEmbsByKeys(embName, keys); + if (retCode != H_OK) { + LOG_ERROR("RemoveEmbsByKeys failed!"); return; } - vector ssdKeys; - for (auto& key : keys) { - if (cacheManager->IsKeyInSSD(embName, key)) { - ssdKeys.emplace_back(key); - } - } - cacheManager->EvictSSDEmbedding(embName, ssdKeys); } -int HybridMgmt::GetStepFromPath(const string& loadPath) const +void HybridMgmt::EvictL3StorageKeys(const string& embName, const vector& keys) const { - regex pattern("sparse-model-(\\d+)"); - smatch match; - if (regex_search(loadPath, match, pattern)) { - int res = 0; - unsigned int minSize = 2; - if (match.size() < minSize) { - return res; - } - try { - res = stoi(match[1]); - } catch (const std::invalid_argument& e) { - LOG_ERROR(e.what()); - } catch (const std::out_of_range& e) { - LOG_ERROR(e.what()); - } - return res; + if (!isL3StorageEnabled) { + return; } - return 0; + cacheManager->EvictL3StorageEmbedding(embName, keys); } /// 通过pyBind在python侧调用,通知hybridMgmt上层即将进行图的执行,需要进行唤醒 @@ -1129,38 +859,33 @@ void HybridMgmt::CountStepBySessionRun(int channelID, int steps) const /// \return 表使用大小 int64_t HybridMgmt::GetTableSize(const string& embName) const { + int64_t size = -1; #ifndef GTEST if (!isInitialized) { throw runtime_error("HybridMgmt not initialized. Call Initialize first."); } if (mgmtRankInfo.useDynamicExpansion) { - int64_t size = EmbeddingMgmt::Instance()->GetSize(embName); + size = EmbeddingMgmt::Instance()->GetSize(embName); LOG_INFO(MGMT + "dynamic expansion mode, get emb:[{}] size:{}", embName, size); return size; } if (!mgmtRankInfo.isDDR) { size_t maxOffset = EmbeddingMgmt::Instance()->GetMaxOffset(embName); - int64_t size = static_cast(maxOffset); + size = static_cast(maxOffset); LOG_INFO(MGMT + "HBM mode, get emb:[{}] size:{}", embName, size); return size; } - int64_t ssdSize = 0; - if (mgmtRankInfo.isSSDEnabled) { - ssdSize = cacheManager->GetTableEmbeddingSize(embName); - } - - const auto& iter = hostHashMaps->embHashMaps.find(embName); - if (iter == hostHashMaps->embHashMaps.end()) { - LOG_ERROR(MGMT + "get maxOffset, wrong embName:{} ", embName); - return -1; + int64_t l3StorageUsage = 0; + if (isL3StorageEnabled) { + l3StorageUsage = cacheManager->GetTableUsage(embName); } - auto maxOffset = hostHashMaps->embHashMaps.at(embName).maxOffset; - int64_t size = static_cast(maxOffset) + ssdSize; - LOG_INFO(MGMT + "DDR/SSD mode, get emb:[{}] size:{}", embName, size); - return size; + uint32_t ddrSize = embCache->GetUsage(embName); + size = static_cast(ddrSize) + l3StorageUsage; + LOG_INFO(MGMT + "DDR/L3Storage mode, get emb:[{}] size:{}", embName, size); #endif + return size; } /// 获取table表容量大小 @@ -1179,8 +904,8 @@ int64_t HybridMgmt::GetTableCapacity(const string& embName) const return capacity; } LOG_WARN(MGMT + "no dynamic expansion mode, get emb:[{}] capacity failed", embName); - return -1; #endif + return -1; } /// 设置表的优化器信息 @@ -1194,3 +919,1330 @@ void HybridMgmt::SetOptimizerInfo(const string& embName, OptimizerInfo optimInfo } EmbeddingMgmt::Instance()->SetOptimizerInfo(embName, optimInfo); } + +// L3Storage +void HybridMgmt::LookUpAndRemoveAddrs(const EmbTaskInfo& info) +{ + uint64_t memSize = info.extEmbeddingSize * sizeof(float); + const std::string hbmSwapKeyQueName = "HBMSwapKeyQue"; + const 
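// Aside: LookUpAndRemoveAddrs below deep-copies DDR swap-out rows (malloc + memcpy_s)
// before EmbeddingRemove frees their cache slots, so the addresses pushed into
// DDRSwapAddrsQue for SWAP_OUT_STR are heap copies that the consumer owns. Sketch of
// that hand-off; WriteRow and the consumer loop are assumptions, not from the sources:
//
//     std::vector<float*> addrs = DDRSwapAddrsQue[name + SWAP_OUT_STR].WaitAndPop();
//     for (float* addr : addrs) {
//         WriteRow(addr);  // hypothetical sink, e.g. persisting to L3Storage
//         free(addr);      // heap copy made with malloc() in LookUpAndRemoveAddrs
//     }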
std::string ddrSwapKeyQueName = "DDRSwapKeyQue"; + auto lookUpFunc = [this, memSize, info](std::map>>& fromQue, + std::map>>& toQue, + const string& swapStr, const string& fromQueName) { + std::vector keys = fromQue[info.name + swapStr].WaitAndPop(); + if (!isRunning) { + return; + } + std::vector addrs; + TimeCost lookupAddrsTC; + int rc = embCache->EmbeddingLookupAddrs(info.name, keys, addrs); + if (rc != H_OK) { + LOG_ERROR("lookUpAddrs, table:{}, fromQue: {}, swapStr:{}, keys.size:{}, addrs.size:{}, pushId:{}", + info.name, fromQueName, swapStr, keys.size(), addrs.size(), info.batchId); + throw runtime_error("EmbeddingLookupAddrs failed! error code:" + std::to_string(rc)); + } + if (&fromQue == &DDRSwapKeyQue && swapStr == SWAP_OUT_STR) { + for (auto& addr : addrs) { + auto* newAddr = (float*)malloc(memSize); + rc = memcpy_s(newAddr, memSize, addr, memSize); + if (rc != 0) { + throw runtime_error("memcpy_s failed! error code:" + std::to_string(rc)); + } + addr = newAddr; + } + rc = embCache->EmbeddingRemove(info.name, keys); + if (rc != H_OK) { + throw runtime_error("EmbeddingRemove failed! error code:" + std::to_string(rc)); + } + } + LOG_DEBUG("table:{}, fromQue:{}, swapStr:{}, keys.size:{}, addrs.size:{}, pushId:{}, lookupAddrsTC(ms):{}", + info.name, fromQueName, swapStr, keys.size(), addrs.size(), info.batchId, lookupAddrsTC.ElapsedMS()); + toQue[info.name + swapStr].Pushv(addrs); + }; + + lookUpFunc(DDRSwapKeyQue, DDRSwapAddrsQue, SWAP_OUT_STR, ddrSwapKeyQueName); + lookUpFunc(DDRSwapKeyQue, DDRSwapAddrsQue, SWAP_IN_STR, ddrSwapKeyQueName); + lookUpFunc(HBMSwapKeyQue, HBMSwapAddrsQue, SWAP_IN_STR, hbmSwapKeyQueName); + lookUpFunc(HBMSwapKeyQue, HBMSwapAddrsQue, SWAP_OUT_STR, hbmSwapKeyQueName); + lookUpSwapInAddrsPushId[info.name]++; +} + +// DDR +void HybridMgmt::LookUpSwapAddrs(const string& embName) +{ + int id = 0; + std::string swapInName = embName + SWAP_IN_STR; + std::string swapOutName = embName + SWAP_OUT_STR; + std::vector addrs; + while (isRunning && lookupAddrSuccess) { + if (!isRunning) { + return; + } + // swap in + std::vector keys = HBMSwapKeyQue[swapInName].WaitAndPop(); + TimeCost lookupAddrsInTC; + int rc = embCache->EmbeddingLookupAddrs(embName, keys, addrs); + if (rc != H_OK) { + lookupAddrSuccess = false; + throw runtime_error("EmbeddingLookupAddrs failed! error code: " + std::to_string(rc)); + } + LOG_DEBUG("table:{}, swapStr:{}, keys.size:{}, addrs.size:{}, pushId:{}, lookupAddrsInTC(ms):{}", embName, + SWAP_IN_STR, keys.size(), addrs.size(), id, lookupAddrsInTC.ElapsedMS()); + HBMSwapAddrsQue[swapInName].Pushv(addrs); + + lookUpSwapInAddrsPushId[embName]++; + LOG_DEBUG("LookUpSwapAddrs, table:{}, pushId:{}, lookUpSwapInAddrsPushId:{}", embName, id, + lookUpSwapInAddrsPushId[embName]); + + // swap out + keys = HBMSwapKeyQue[swapOutName].WaitAndPop(); + TimeCost lookupAddrsOutTC; + rc = embCache->EmbeddingLookupAddrs(embName, keys, addrs); + if (rc != H_OK) { + lookupAddrSuccess = false; + throw runtime_error("EmbeddingLookupAddrs failed! 
error code: " + std::to_string(rc)); + } + LOG_DEBUG("table:{}, swapStr:{}, keys.size:{}, addrs.size:{}, pushId:{}, lookupAddrsOutTC(ms):{}", embName, + SWAP_OUT_STR, keys.size(), addrs.size(), id, lookupAddrsOutTC.ElapsedMS()); + HBMSwapAddrsQue[swapOutName].Pushv(addrs); + id++; + } +} + +/// 导出npu的embedding +void HybridMgmt::FetchDeviceEmb() +{ + // 数据处理线程上锁 + KEY_PROCESS_INSTANCE->LoadSaveLock(); + + if (mgmtRankInfo.isDDR) { + // DDR模式保存host的emb表以及hashmap + LOG_DEBUG(MGMT + "start host side save: ddr mode"); + for (const auto& embInfo : mgmtEmbInfo) { + std::vector> koVec; + embCache->ExportDeviceKeyOffsetPairs(embInfo.name, koVec); + std::vector swapOutPos; + for (const auto& p : koVec) { + swapOutPos.push_back(p.second); + } + + vector swapTensor; + swapTensor.emplace_back(Vec2TensorI32(swapOutPos)); + swapTensor.emplace_back(Tensor(tensorflow::DT_INT32, {1})); + auto swapOutLen = swapTensor.back().flat(); + swapOutLen(0) = swapOutPos.size(); + LOG_DEBUG(MGMT + "save swapOutPos size:{}", swapOutPos.size()); + // 发送SwapOutPos信息 + hdTransfer->Send(TransferChannel::SAVE_H2D, swapTensor, TRAIN_CHANNEL_ID, embInfo.name); + } + } + KEY_PROCESS_INSTANCE->LoadSaveUnlock(); +} + +// 这里就是新增的embedding处理线程 +void HybridMgmt::EmbeddingTask() +{ + for (const auto& embInfo : mgmtEmbInfo) { + lastUpdateFinishStepMap[embInfo.name] = 0; + lastLookUpFinishStepMap[embInfo.name] = 0; + lastSendFinishStepMap[embInfo.name] = 0; + lastRecvFinishStepMap[embInfo.name] = 0; + } + + TimeCost embHDTransTC; + MultiThreadEmbHDTransWrap(); + LOG_DEBUG("embHDTransTC(ms):{}", embHDTransTC.ElapsedMS()); +} + +void HybridMgmt::MultiThreadEmbHDTransWrap() +{ + for (int index = 0; index < EMBEDDING_THREAD_NUM; index++) { + for (const auto& embInfo : mgmtEmbInfo) { + CreateEmbeddingLookUpAndSendThread(index, embInfo); + CreateEmbeddingReceiveAndUpdateThread(index, embInfo); + } + } +} + +void HybridMgmt::EmbeddingLookUpAndSendDDR(int batchId, int index, const EmbInfo& embInfo) +{ + int cvNotifyIndex = 0; + if (index + 1 != EMBEDDING_THREAD_NUM) { + cvNotifyIndex = index + 1; + } + + EmbTaskInfo info = {.batchId = batchId, + .threadIdx = index, + .cvNotifyIndex = cvNotifyIndex, + .extEmbeddingSize = embInfo.extEmbeddingSize, + .name = embInfo.name}; + vector h2dEmb; + + auto isSuccess = EmbeddingLookUpDDR(info, h2dEmb); + if (!isSuccess) { + LOG_INFO("HybridMgmt is not running"); + return; + } + + EmbeddingSendDDR(info, h2dEmb); +} + +void HybridMgmt::EmbeddingReceiveAndUpdateDDR(int batchId, int index, const EmbInfo& embInfo) +{ + int cvNotifyIndex = 0; + if (index + 1 != EMBEDDING_THREAD_NUM) { + cvNotifyIndex = index + 1; + } + + EmbTaskInfo info = {.batchId = batchId, + .threadIdx = index, + .cvNotifyIndex = cvNotifyIndex, + .extEmbeddingSize = embInfo.extEmbeddingSize, + .name = embInfo.name}; + + float* ptr = nullptr; + vector swapOutAddrs; + auto isSuccess = EmbeddingReceiveDDR(info, ptr, swapOutAddrs); + if (!isSuccess) { + LOG_INFO("HybridMgmt is not running"); + return; + } + + EmbeddingUpdateDDR(info, ptr, swapOutAddrs); +} + +void HybridMgmt::EmbeddingLookUpAndSendL3Storage(int batchId, int index, const EmbInfo& embInfo) +{ + int cvNotifyIndex = 0; + if (index + 1 != EMBEDDING_THREAD_NUM) { + cvNotifyIndex = index + 1; + } + + EmbTaskInfo info = {.batchId = batchId, + .threadIdx = index, + .cvNotifyIndex = cvNotifyIndex, + .extEmbeddingSize = embInfo.extEmbeddingSize, + .name = embInfo.name}; + vector h2dEmb; + + auto isSuccess = EmbeddingLookUpL3Storage(info, h2dEmb); + if (!isSuccess) { + LOG_INFO("HybridMgmt 
is not running"); + return; + } + + EmbeddingSendL3Storage(info, h2dEmb); +} + +void HybridMgmt::EmbeddingReceiveAndUpdateL3Storage(int batchId, int index, const EmbInfo& embInfo) +{ + int cvNotifyIndex = 0; + if (index + 1 != EMBEDDING_THREAD_NUM) { + cvNotifyIndex = index + 1; + } + + EmbTaskInfo info = {.batchId = batchId, + .threadIdx = index, + .cvNotifyIndex = cvNotifyIndex, + .extEmbeddingSize = embInfo.extEmbeddingSize, + .name = embInfo.name}; + + float* ptr = nullptr; + vector swapOutAddrs; + int64_t dims0 = 0; + EmbeddingReceiveL3Storage(info, ptr, swapOutAddrs, dims0); + + EmbeddingUpdateL3Storage(info, ptr, swapOutAddrs, dims0); +} + +/// 构造训练所需的各种向量数据 +/// \param embName 表名 +/// \param batchId 已处理的batch数 +/// \param channelId 通道索引(训练/推理) +/// \param remainBatchOut 是否从通道获取了数据 +/// \return 是否处理成功 +void HybridMgmt::ProcessEmbInfoL3Storage(const EmbBaseInfo& info, bool& remainBatchOut) +{ +#ifndef GTEST + TimeCost getAndSendTensorsTC; + LOG_DEBUG("ProcessEmbInfoL3Storage table:{}, channel:{}, batchId:{}", info.name, info.channelId, info.batchId); + + if (info.channelId == TRAIN_CHANNEL_ID && info.batchId == hybridMgmtBlock->maxTrainStep) { + HandleReachMaxStepCase(info, remainBatchOut); + return; + } + + // 只有在每次GetUniqueKeys的时候才知道上游是否已经EOS + // 注意GetUniqueKeys与EOS关联,需要在ProcessEmbInfoL3Storage最先调用,如需调整位置,请参考并适配其他函数 + // 获取GlobalUnique向量 + auto uniqueKeys = GetUniqueKeys(info, remainBatchOut); + if (uniqueKeys.empty()) { + return; + } + + // 获取GlobalUnique对应的restoreVectorSec + auto restoreVecSec = GetRestoreVecSec(info, remainBatchOut); + if (restoreVecSec.empty()) { + return; + } + + SendAll2AllVec(info, remainBatchOut); + if (!remainBatchOut) { + return; + } + + SendRestoreVec(info, remainBatchOut); + if (!remainBatchOut) { + return; + } + + std::pair, vector> swapInKoPair; + std::pair, vector> swapOutKoPair; + GetSwapPairsAndKey2Offset(info, uniqueKeys, swapInKoPair, swapOutKoPair); + + SendLookupOffsets(info, uniqueKeys, restoreVecSec); + + SendGlobalUniqueVec(info, uniqueKeys, restoreVecSec); + + TimeCost swapProcessTC; + auto& swapInKeys = swapInKoPair.first; + auto& swapInPos = swapInKoPair.second; + auto& swapOutKeys = swapOutKoPair.first; + auto& swapOutPos = swapOutKoPair.second; + auto lastSwapInPos = lastSwapInPosMap[info.name]; + lastSwapInPosMap[info.name] = swapInPos; // 暂存待下一步发送 + + auto isNeedReturn = HandleSpecialProcessStatusL3Storage(info, getAndSendTensorsTC, swapInKoPair, swapOutKoPair); + if (isNeedReturn) { + return; + } + + HandleDataSwapForL3Storage(info, swapInKeys, swapOutKeys); + + // 下发swaptensor + if (info.batchId != 0) { + SendTensorForSwap(info, lastSwapInPos, swapOutPos); + } + + HandleEndBatchCase(info, swapInPos); + + if (info.channelId == TRAIN_CHANNEL_ID) { + alreadyTrainOnce = true; + } + + LOG_DEBUG("ProcessEmbInfoL3Storage end, table:{}, batchId:{}, swapProcessTC(ms):{}, getAndSendTensorsTC(ms):{}", + info.name, info.batchId, swapProcessTC.ElapsedMS(), getAndSendTensorsTC.ElapsedMS()); +#endif +} + +void HybridMgmt::SendTensorForSwap(const EmbBaseInfo& info, const vector& swapInPosUint, + const vector& swapOutPosUint) +{ +#ifndef GTEST + vector swapTensor; + swapTensor.emplace_back(Vec2TensorI32(swapInPosUint)); + swapTensor.emplace_back(Vec2TensorI32(swapOutPosUint)); + swapTensor.emplace_back(Tensor(tensorflow::DT_INT32, {1})); + auto swapInLen = swapTensor.back().flat(); + swapInLen(0) = swapInPosUint.size(); + swapTensor.emplace_back(Tensor(tensorflow::DT_INT32, {1})); + auto swapOutLen = swapTensor.back().flat(); + swapOutLen(0) = 
swapOutPosUint.size(); + + hdTransfer->Send(TransferChannel::SWAP, swapTensor, info.channelId, info.name, info.batchId); +#endif +} + +void HybridMgmt::InitDataPipelineForDDR(const string& embName) +{ + // 初始化公共队列 + HBMSwapKeyQue[embName + SWAP_IN_STR]; + HBMSwapKeyQue[embName + SWAP_OUT_STR]; + HBMSwapAddrsQue[embName + SWAP_IN_STR]; + HBMSwapAddrsQue[embName + SWAP_OUT_STR]; + + // 初始化lookup线程 + lookUpSwapInAddrsPushId[embName]; // 此处初始化,避免多线程竞争导致计数错误 + lookUpSwapInAddrsThreads.emplace_back( + std::async(std::launch::async, [=] { LookUpSwapAddrs(embName); })); + + LOG_DEBUG("data pipeline for ddr init"); +} + +void HybridMgmt::InitDataPipelineForL3Storage(const string& embName, int extEmbeddingSize) +{ + // 初始化公共队列 + HBMSwapKeyQue[embName + SWAP_IN_STR]; + HBMSwapKeyQue[embName + SWAP_OUT_STR]; + HBMSwapAddrsQue[embName + SWAP_IN_STR]; + HBMSwapAddrsQue[embName + SWAP_OUT_STR]; + + HBMSwapKeyQue[embName + ADDR_STR]; + HBMSwapKeyForL3StorageQue[embName + SWAP_IN_STR]; + HBMSwapKeyForL3StorageQue[embName + ADDR_STR]; + HBMSwapKeyForL3StorageQue[embName + SWAP_OUT_STR]; + + DDRSwapKeyQue[embName + SWAP_OUT_STR]; + DDRSwapKeyQue[embName + SWAP_IN_STR]; + DDRSwapKeyForL3StorageQue[embName + SWAP_OUT_STR]; + DDRSwapKeyForL3StorageQue[embName + SWAP_IN_STR]; + DDRSwapAddrsQue[embName + SWAP_OUT_STR]; + DDRSwapAddrsQue[embName + SWAP_IN_STR]; + + // 初始化lookup线程 + LOG_DEBUG("data pipeline for L3Storage init"); +} + +void HybridMgmt::InitEmbeddingCache(const vector& embInfos) +{ + factory->SetExternalLogFuncInner(CTRLog); + factory->CreateEmbCacheManager(embCache); + EmbeddingMgmt::Instance()->SetEmbCacheForEmbTable(embCache); + EmbeddingMgmt::Instance()->SetHDTransferForEmbTable(hdTransfer); + + for (auto embInfo : embInfos) { + if (isL3StorageEnabled) { + InitDataPipelineForL3Storage(embInfo.name, embInfo.extEmbeddingSize); + } else { + InitDataPipelineForDDR(embInfo.name); + } + + specialProcessStatus[embInfo.name] = ProcessStatus::NORMAL; + + // 初始化embedding cache + LOG_INFO("create cache for table:{}, hostVocabSize:{}, extEmbeddingSize:{}, maxCacheSize(devVocabSize):{}", + embInfo.name, embInfo.hostVocabSize, embInfo.extEmbeddingSize, embInfo.devVocabSize); + EmbCache::EmbCacheInfo embCacheInfo(embInfo.name, embInfo.hostVocabSize, embInfo.embeddingSize, + embInfo.extEmbeddingSize, embInfo.devVocabSize); + size_t prefill = std::max(embInfo.hostVocabSize / HOST_TO_PREFILL_RATIO, embInfo.devVocabSize); + int ret = embCache->CreateCacheForTable(embCacheInfo, embInfo.initializeInfos, INVALID_KEY_VALUE, prefill, + EMBEDDING_THREAD_NUM); + if (ret != H_OK) { + throw runtime_error(embInfo.name + "create cache for table failed, error code: " + std::to_string(ret)); + } + } +} + +void HybridMgmt::JoinEmbeddingCacheThread() +{ + for (auto& p : HBMSwapAddrsQue) { + p.second.DestroyQueue(); + } + for (auto& p : HBMSwapKeyQue) { + p.second.DestroyQueue(); + } + for (auto& p : HBMSwapKeyForL3StorageQue) { + p.second.DestroyQueue(); + } + for (auto& p : DDRSwapKeyQue) { + p.second.DestroyQueue(); + } + for (auto& p : DDRSwapKeyForL3StorageQue) { + p.second.DestroyQueue(); + } + for (auto& p : DDRSwapAddrsQue) { + p.second.DestroyQueue(); + } + for (auto& t : EmbeddingLookUpAndSendThreadPool) { + t.join(); + } + for (auto& t : EmbeddingReceiveAndUpdateThreadPool) { + t.join(); + } + for (auto& t : lookUpSwapInAddrsThreads) { + t.wait(); + } + for (auto& t : lookUpSwapOutAddrsThreads) { + t.wait(); + } +} + +void HybridMgmt::HandleReachMaxStepCase(const EmbBaseInfo& info, bool& remainBatchOut) +{ + // 1. 
If no switch happened (status is still normal), a send is needed to finish step n-1
+    // 2. If a switch did happen:
+    //    a. eval finished its run: no send needed, the caller exits naturally
+    //    b. save: being triggered here means the expected train steps are already done (IsTrainEndBatch decides
+    //       the send), so the current step needs no send either
+    LOG_DEBUG("table:{}, batchId:{}, ProcessStatus:{}, reach maxTrainStep", info.name, info.batchId,
+              ProcessStatus2Str(specialProcessStatus[info.name]));
+    if (specialProcessStatus[info.name] == ProcessStatus::NORMAL) {
+        LOG_DEBUG("table:{}, batchId:{}, need send swap tensor"
+                  " for last step to finish train",
+                  info.name, info.batchId);
+        std::vector emptySwapOutPos;
+        SendTensorForSwap(info, lastSwapInPosMap[info.name], emptySwapOutPos);
+    } else {
+        LOG_DEBUG("table:{}, batchId:{}, switch from eval or save, unnecessary to send emptySwapOutPos", info.name,
+                  info.batchId);
+    }
+    remainBatchOut = false;
+    hybridMgmtBlock->SetBlockStatus(TRAIN_CHANNEL_ID, true);
+}
+
+void HybridMgmt::HandleEosCase(const EmbBaseInfo& info, bool& remainBatchOut)
+{
+    LOG_INFO("GetUniqueKeys get eos, handle final batch for current epoch, table:{}, channel:{}, batchId:{}", info.name,
+             info.channelId, info.batchId);
+    bool sendAllChannel = false;
+    if (info.channelId == TRAIN_CHANNEL_ID) {
+        vector emptySwapOutPos;
+        SendTensorForSwap(info, lastSwapInPosMap[info.name], emptySwapOutPos);
+        LOG_INFO("GetUniqueKeys get eos, send pos for train channel, table:{}, batchId:{}", info.name, info.batchId);
+        KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId, sendAllChannel);
+        remainBatchOut = false;
+        return;
+    }
+
+    if (!alreadyTrainOnce) {
+        // predict-only scenario
+        LOG_INFO("ProcessEmbInfoDDR first run in eval channel, assume as predict mode, start handle eos");
+        std::vector emptySwapOutPos;
+        SendTensorForSwap(info, lastSwapInPosMap[info.name], emptySwapOutPos);
+        sendAllChannel = true;
+    } else {
+        hybridMgmtBlock->SetBlockStatus(EVAL_CHANNEL_ID, true);
+        LOG_INFO("GetUniqueKeys get eos from eval channel, SetBlockStatus=true");
+        if (hybridMgmtBlock->IsNeedWaitSave()) {
+            // train+eval+save scenario
+            // A save follows the current step n, which involves a save-to-train state switch. We need to:
+            // 1. Re-send pos so that eval step n-1 starts and completes.
+            // 2. Let eval step n end when it meets EOS.
+            // 3. Run the save; once it finishes it wakes train's ProcessEmbInfoDDR, so specialProcessStatus must
+            //    be changed before that happens.
+            LOG_DEBUG("eval encounter eos and need save after this step, "
+                      "send pos and change specialProcessStatus, current status:{}, modify to status:{}",
+                      ProcessStatus2Str(specialProcessStatus[info.name]),
+                      ProcessStatus2Str(ProcessStatus::AFTER_SWITCH_FIRST_BATCH));
+            vector emptySwapOutPos;
+            SendTensorForSwap(info, lastSwapInPosMap[info.name], emptySwapOutPos);
+            specialProcessStatus[info.name] = ProcessStatus::AFTER_SWITCH_FIRST_BATCH;
+        } else {
+            // train+eval+train scenario
+            // Leave it to train's ProcessEmbInfoDDR to kick off the final n-1 eval steps:
+            // train sends pos so eval step n-1 finishes; at eval step n each channel ends on EOS
+            // (except the channels shared between train and eval)
+            LOG_INFO("GetUniqueKeys get eos, skip send pos for eval channel, table:{}, batchId:{}", info.name,
+                     info.batchId);
+        }
+    }
+    KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId, sendAllChannel);
+    remainBatchOut = false;
+}
+
+bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector<float*>& swapOutAddrs)
+{
+    std::unique_lock<std::mutex> lastRecvFinishLocker(lastRecvFinishMutexMap[info.name][info.threadIdx]);
+    cvLastRecvFinishMap[info.name][info.threadIdx].wait(lastRecvFinishLocker, [info, this] {
+        return (lastRecvFinishStepMap[info.name] == info.batchId) || mutexDestroy;
+    });
+    if (!isRunning) {
+        return false;
+    }
+    TimeCost EmbeddingRecvTC = TimeCost();
+
+    swapOutAddrs = HBMSwapAddrsQue[info.name + SWAP_OUT_STR].WaitAndPop();
+    if (!isRunning) {
+        return false;
+    }
+    // Wait for graph execution to send the d2h embedding over
+    if (info.batchId != 0) {
+        TransferChannel transferName = TransferChannel::D2H;
+        auto size = hdTransfer->RecvAcl(transferName, TRAIN_CHANNEL_ID, info.name, info.threadIdx, info.batchId);
+        if (size == 0) {
+            LOG_WARN(HOSTEMB + "recv empty data");
+            return false;
+        }
+
+        auto aclData = acltdtGetDataItem(hdTransfer->aclDatasets[info.name][info.threadIdx], 0);
+        if (aclData == nullptr) {
+            throw runtime_error("Acl get tensor data from dataset failed.");
+        }
+        ptr = reinterpret_cast<float*>(acltdtGetDataAddrFromItem(aclData));
+
+        // Check that the number of embeddings received equals the number of swapOutKeys
+        size_t dimNum = acltdtGetDimNumFromItem(aclData);
+        std::vector<int64_t> dims(dimNum);
+        acltdtGetDimsFromItem(aclData, dims.data(), dimNum);
+
+        LOG_DEBUG("table:{}, batchId:{}, dims[0]:{}, swapOutAddrs size:{}", info.name, info.batchId, dims[0],
+                  swapOutAddrs.size());
+
+        if (dims[0] != static_cast<int64_t>(swapOutAddrs.size())) {
+            throw runtime_error("data dims[0] != swapOutKeys.size()");
+        }
+    }
+    LOG_DEBUG("table:{}, batchId:{}, thread:{}, EmbeddingRecvTC(ms):{}", info.name, info.batchId, info.threadIdx,
+              EmbeddingRecvTC.ElapsedMS());
+    lastRecvFinishStepMap[info.name]++;
+    cvLastRecvFinishMap[info.name][info.cvNotifyIndex].notify_all();
+
+    return true;
+}
+
+void HybridMgmt::EmbeddingUpdateDDR(const EmbTaskInfo& info, const float* embPtr, vector<float*>& swapOutAddrs)
+{
+    std::unique_lock<std::mutex> lastUpdateFinishLocker(lastUpdateFinishMutexMap[info.name][info.threadIdx]);
+    cvLastUpdateFinishMap[info.name][info.threadIdx].wait(lastUpdateFinishLocker, [info, this] {
+        return (lastUpdateFinishStepMap[info.name] == info.batchId) || mutexDestroy;
+    });
+    TimeCost EmbeddingUpdateTC = TimeCost();
+
+    uint64_t memSize = info.extEmbeddingSize * sizeof(float);
+    uint64_t extEmbeddingSize = info.extEmbeddingSize;
+#pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) \
+    shared(swapOutAddrs, embPtr, extEmbeddingSize, memSize)
+    for (uint64_t i = 0; i < swapOutAddrs.size(); i++) {
+        auto rc = memcpy_s(swapOutAddrs[i], memSize, embPtr + i * extEmbeddingSize, memSize);
+        if (rc != 0) {
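+            // memcpy_s is the bounded securec copy and returns a nonzero error code on failure; an exception
+            // cannot legally escape the OpenMP parallel region, so a failed copy here ends the process rather
+            // than leaving the host cache partially updated.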
+ throw runtime_error("memcpy_s failed, error code:" + to_string(rc)); + } + } + if (MxRec::Logger::GetLevel() <= MxRec::Logger::DEBUG) { + string sample; + if (!swapOutAddrs.empty()) { + sample = FloatPtrToLimitStr(swapOutAddrs.front(), info.extEmbeddingSize); // print first element + } + LOG_DEBUG("table:{}, batchId:{}, thread:{}, receive d2hEmb, ext emb:{}, emb size:{}, emb samples:{}, " + "EmbeddingUpdateTC(ms):{}", + info.name.c_str(), info.batchId, info.threadIdx, info.extEmbeddingSize, swapOutAddrs.size(), sample, + EmbeddingUpdateTC.ElapsedMS()); + } + + lastUpdateFinishStepMap[info.name]++; + cvLastUpdateFinishMap[info.name][info.cvNotifyIndex].notify_all(); +} + +bool HybridMgmt::EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2dEmb) +{ + std::unique_lock lastUpdateFinishLocker(lastUpdateFinishMutexMap[info.name][info.threadIdx]); + cvLastUpdateFinishMap[info.name][info.threadIdx].wait(lastUpdateFinishLocker, [info, this] { + return (lastUpdateFinishStepMap[info.name] >= info.batchId) || mutexDestroy; + }); + if (!isRunning) { + return false; + } + + std::unique_lock lastLookUpFinishLocker(lastLookUpFinishMutexMap[info.name][info.threadIdx]); + cvLastLookUpFinishMap[info.name][info.threadIdx].wait(lastLookUpFinishLocker, [info, this] { + return (lastLookUpFinishStepMap[info.name] == info.batchId) || mutexDestroy; + }); + if (!isRunning) { + return false; + } + + bool isSuccess = BuildH2DEmbedding(info, h2dEmb); + if (!isSuccess) { + return false; + } + + lastLookUpFinishStepMap[info.name]++; + cvLastLookUpFinishMap[info.name][info.cvNotifyIndex].notify_all(); + + return true; +} + +void HybridMgmt::EmbeddingSendDDR(const EmbTaskInfo& info, vector& h2dEmb) +{ + std::unique_lock lastSendFinishLocker(lastSendFinishMutexMap[info.name][info.threadIdx]); + cvLastSendFinishMap[info.name][info.threadIdx].wait(lastSendFinishLocker, [info, this] { + return (lastSendFinishStepMap[info.name] == info.batchId) || mutexDestroy; + }); + TimeCost SendTC = TimeCost(); + hdTransfer->Send(TransferChannel::H2D, h2dEmb, TRAIN_CHANNEL_ID, info.name, info.batchId); + lastSendFinishStepMap[info.name]++; + cvLastSendFinishMap[info.name][info.cvNotifyIndex].notify_all(); + LOG_DEBUG("table:{}, batchId:{}, thread:{}, SendH2DEmbTC(ms):{}", info.name, info.batchId, info.threadIdx, + SendTC.ElapsedMS()); + + // 对于end of sequence场景,key + // process需要基于h2dNextBatchId等待每个table都完成了最后1个step发送,才能发EOS至各channel + hybridMgmtBlock->h2dNextBatchId[info.name]++; + LOG_DEBUG("h2dNextBatchId, table:{}, next batchId:{}", info.name, hybridMgmtBlock->h2dNextBatchId[info.name]); +} + +void HybridMgmt::CreateEmbeddingLookUpAndSendThread(int index, const EmbInfo& embInfo) +{ + EmbeddingLookUpAndSendThreadPool.emplace_back([index, embInfo, this]() { + while (true) { + lookUpAndSendBatchIdMtx.lock(); + if (lookUpAndSendTableBatchMap[embInfo.name] % EMBEDDING_THREAD_NUM == index) { + int cur_batch_id = lookUpAndSendTableBatchMap[embInfo.name]; + lookUpAndSendTableBatchMap[embInfo.name]++; + lookUpAndSendBatchIdMtx.unlock(); + if (!isL3StorageEnabled) { + EmbeddingLookUpAndSendDDR(cur_batch_id, index, embInfo); + } else { + EmbeddingLookUpAndSendL3Storage(cur_batch_id, index, embInfo); + } + } else { + lookUpAndSendBatchIdMtx.unlock(); + } + if (!isRunning) { + return; + } + } + }); +} + +void HybridMgmt::CreateEmbeddingReceiveAndUpdateThread(int index, const EmbInfo& embInfo) +{ + EmbeddingReceiveAndUpdateThreadPool.emplace_back([index, embInfo, this]() { + while (true) { + receiveAndUpdateBatchIdMtx.lock(); + if 
(receiveAndUpdateTableBatchMap[embInfo.name] % EMBEDDING_THREAD_NUM == index) { + int cur_batch_id = receiveAndUpdateTableBatchMap[embInfo.name]; + receiveAndUpdateTableBatchMap[embInfo.name]++; + receiveAndUpdateBatchIdMtx.unlock(); + if (!isL3StorageEnabled) { + EmbeddingReceiveAndUpdateDDR(cur_batch_id, index, embInfo); + } else { + EmbeddingReceiveAndUpdateL3Storage(cur_batch_id, index, embInfo); + } + } else { + receiveAndUpdateBatchIdMtx.unlock(); + } + if (!isRunning) { + return; + } + } + }); +} + +bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, + int64_t& dims0) +{ + std::unique_lock lastRecvFinishLocker(lastRecvFinishMutexMap[info.name][info.threadIdx]); + cvLastRecvFinishMap[info.name][info.threadIdx].wait(lastRecvFinishLocker, [info, this] { + return (lastRecvFinishStepMap[info.name] == info.batchId) || mutexDestroy; + }); + if (!isRunning) { + return false; + } + // DDR swap out key need to be removed + LookUpAndRemoveAddrs(info); + + TimeCost EmbeddingRecvTC = TimeCost(); + // finish时会pop空vector,因此需要额外判定isRunning + swapOutAddrs = HBMSwapAddrsQue[info.name + SWAP_OUT_STR].WaitAndPop(); + if (!isRunning) { + return false; + } + // 等待图执行发送d2h embedding过来 + if (info.batchId != 0) { + TransferChannel transferName = TransferChannel::D2H; + auto size = hdTransfer->RecvAcl(transferName, TRAIN_CHANNEL_ID, info.name, info.threadIdx, info.batchId); + if (size == 0) { + LOG_WARN(HOSTEMB + "recv empty data"); + return false; + } + + auto aclData = acltdtGetDataItem(hdTransfer->aclDatasets[info.name][info.threadIdx], 0); + if (aclData == nullptr) { + throw runtime_error("Acl get tensor data from dataset failed."); + } + ptr = reinterpret_cast(acltdtGetDataAddrFromItem(aclData)); + + // 判断拿到的embedding个数是否与swapOutKeys个数相等 + size_t dimNum = acltdtGetDimNumFromItem(aclData); + int64_t dims[dimNum]; + acltdtGetDimsFromItem(aclData, dims, dimNum); + + LOG_DEBUG("table:{}, batchId:{}, recv d2h, dims[0]:{}, swapOutAddrs.size:{}", info.name, info.batchId, dims[0], + swapOutAddrs.size()); + dims0 = dims[0]; + } + LOG_DEBUG("table:{}, batchId:{}, thread:{}, EmbeddingRecvTC(ms):{}", info.name.c_str(), info.batchId, + info.threadIdx, EmbeddingRecvTC.ElapsedMS()); + lastRecvFinishStepMap[info.name]++; + cvLastRecvFinishMap[info.name][info.cvNotifyIndex].notify_all(); + return true; +} + +void HybridMgmt::EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float* embPtr, vector& swapOutAddrs, + int64_t& dims0) +{ + std::unique_lock lastUpdateFinishLocker(lastUpdateFinishMutexMap[info.name][info.threadIdx]); + cvLastUpdateFinishMap[info.name][info.threadIdx].wait(lastUpdateFinishLocker, [info, this] { + return (lastUpdateFinishStepMap[info.name] == info.batchId) || mutexDestroy; + }); + + TimeCost EmbeddingUpdateTC = TimeCost(); + std::vector swapOutDDRAddrOffs = HBMSwapKeyQue[info.name + ADDR_STR].WaitAndPop(); + if (!isRunning) { + return; + } + uint64_t memSize = info.extEmbeddingSize * sizeof(float); + uint64_t extEmbeddingSize = info.extEmbeddingSize; + // DDR更新 +#pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) \ + shared(swapOutAddrs, swapOutDDRAddrOffs, embPtr, extEmbeddingSize, memSize) + for (uint64_t i = 0; i < swapOutAddrs.size(); i++) { + auto rc = memcpy_s(swapOutAddrs[i], memSize, embPtr + swapOutDDRAddrOffs[i] * extEmbeddingSize, memSize); + if (rc != 0) { + throw runtime_error("memcpy_s failed, error code:" + to_string(rc)); + } + } + LOG_DEBUG("table:{}, batchId:{}, thread:{}, EmbeddingUpdateTC(ms):{}", 
info.name.c_str(), info.batchId,
+              info.threadIdx, EmbeddingUpdateTC.ElapsedMS());
+
+    // L3Storage update
+    TimeCost L3StorageUpdateTC = TimeCost();
+    std::vector swapOutL3StorageAddrOffs = HBMSwapKeyForL3StorageQue[info.name + ADDR_STR].WaitAndPop();
+    std::vector swapOutL3StorageKeys = HBMSwapKeyForL3StorageQue[info.name + SWAP_OUT_STR].WaitAndPop();
+    if (!isRunning) {
+        return;
+    }
+
+    if (dims0 != static_cast<int64_t>(swapOutAddrs.size() + swapOutL3StorageKeys.size())) {
+        throw runtime_error("data dims[0] != swapOutAddrs.size() + swapOutL3StorageKeys.size()");
+    }
+    cacheManager->UpdateL3StorageEmb(info.name, embPtr, extEmbeddingSize, swapOutL3StorageKeys,
+                                     swapOutL3StorageAddrOffs);
+    LOG_DEBUG("table:{}, batchId:{}, thread:{}, L3StorageUpdateTC(ms):{}", info.name.c_str(), info.batchId,
+              info.threadIdx, L3StorageUpdateTC.ElapsedMS());
+
+    lastUpdateFinishStepMap[info.name]++;
+    cvLastUpdateFinishMap[info.name][info.cvNotifyIndex].notify_all();
+}
+
+bool HybridMgmt::EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vector<Tensor>& h2dEmb)
+{
+    std::unique_lock<std::mutex> lastUpdateFinishLocker(lastUpdateFinishMutexMap[info.name][info.threadIdx]);
+    cvLastUpdateFinishMap[info.name][info.threadIdx].wait(lastUpdateFinishLocker, [info, this] {
+        return (lastUpdateFinishStepMap[info.name] >= info.batchId) || mutexDestroy;
+    });
+    if (!isRunning) {
+        return false;
+    }
+
+    std::unique_lock<std::mutex> lastLookUpFinishLocker(lastLookUpFinishMutexMap[info.name][info.threadIdx]);
+    cvLastLookUpFinishMap[info.name][info.threadIdx].wait(lastLookUpFinishLocker, [info, this] {
+        return (lastLookUpFinishStepMap[info.name] == info.batchId) || mutexDestroy;
+    });
+    if (!isRunning) {
+        return false;
+    }
+
+    TimeCost transferDDR2L3StorageTC = TimeCost();
+    // Free DDR space first
+    std::vector DDR2L3StorageKeys = DDRSwapKeyForL3StorageQue[info.name + SWAP_OUT_STR].WaitAndPop();
+    std::vector DDR2L3StorageAddrs = DDRSwapAddrsQue[info.name + SWAP_OUT_STR].WaitAndPop();
+    if (!isRunning) {
+        return false;
+    }
+    cacheManager->TransferDDR2L3Storage(info.name, info.extEmbeddingSize, DDR2L3StorageKeys, DDR2L3StorageAddrs);
+    LOG_DEBUG("table:{}, thread:{}, transferDDR2L3StorageTC(ms):{}", info.name.c_str(), info.threadIdx,
+              transferDDR2L3StorageTC.ElapsedMS());
+
+    TimeCost fetchL3StorageEmb2DDRTC = TimeCost();
+    // Move the swapInKeys that currently live in L3Storage up to DDR
+    std::vector L3Storage2DDRKeys = DDRSwapKeyForL3StorageQue[info.name + SWAP_IN_STR].WaitAndPop();
+    std::vector L3Storage2DDRAddrs = DDRSwapAddrsQue[info.name + SWAP_IN_STR].WaitAndPop();
+    if (!isRunning) {
+        return false;
+    }
+    cacheManager->FetchL3StorageEmb2DDR(info.name, info.extEmbeddingSize, L3Storage2DDRKeys, L3Storage2DDRAddrs);
+    LOG_DEBUG("table:{}, thread:{}, fetchL3StorageEmb2DDRTC(ms):{}", info.name.c_str(), info.threadIdx,
+              fetchL3StorageEmb2DDRTC.ElapsedMS());
+
+    bool isSuccess = BuildH2DEmbedding(info, h2dEmb);
+    if (!isSuccess) {
+        return false;
+    }
+
+    lastLookUpFinishStepMap[info.name]++;
+    cvLastLookUpFinishMap[info.name][info.cvNotifyIndex].notify_all();
+
+    return true;
+}
+
+void HybridMgmt::EmbeddingSendL3Storage(const EmbTaskInfo& info, vector<Tensor>& h2dEmb)
+{
+    std::unique_lock<std::mutex> lastSendFinishLocker(lastSendFinishMutexMap[info.name][info.threadIdx]);
+    cvLastSendFinishMap[info.name][info.threadIdx].wait(lastSendFinishLocker, [info, this] {
+        return (lastSendFinishStepMap[info.name] == info.batchId) || mutexDestroy;
+    });
+    TimeCost SendTC = TimeCost();
+    hdTransfer->Send(TransferChannel::H2D, h2dEmb, TRAIN_CHANNEL_ID, info.name, info.batchId);
+    lastSendFinishStepMap[info.name]++;
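+    // Wake the thread that owns the next batch of this table (cvNotifyIndex is index + 1, wrapping to 0 at the
+    // end of the thread pool)
+    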
cvLastSendFinishMap[info.name][info.cvNotifyIndex].notify_all(); + LOG_DEBUG("table:{}, thread:{}, SendH2DEmbTC(ms):{}", info.name.c_str(), info.threadIdx, SendTC.ElapsedMS()); + + // 对于end of sequence场景,key + // process需要基于h2dNextBatchId等待每个table都完成了最后1个step发送,才能发EOS至各channel + hybridMgmtBlock->h2dNextBatchId[info.name]++; + LOG_DEBUG("h2dNextBatchId, table:{}, next batchId:{}", info.name, hybridMgmtBlock->h2dNextBatchId[info.name]); +} + +void HybridMgmt::HandleEosCaseHBM(const string& embName, int batchId, int channelId, bool& remainBatchOut) +{ + bool sendAllChannel = false; + if (channelId == EVAL_CHANNEL_ID) { + if (!alreadyTrainOnce) { + // predict场景 + sendAllChannel = true; + } else { + // train+eval场景 + hybridMgmtBlock->SetBlockStatus(EVAL_CHANNEL_ID, true); + LOG_INFO("GetUniqueKeys get eos from eval channel, SetBlockStatus=true"); + } + } + KEY_PROCESS_INSTANCE->SendEos(embName, batchId, channelId, sendAllChannel); + remainBatchOut = false; +} + +void HybridMgmt::HandleEndBatchCase(const EmbBaseInfo& info, vector& swapInPos) +{ + if ((info.channelId == TRAIN_CHANNEL_ID) && IsTrainEndBatch(info.batchId)) { + // 如果是train epoch最后一个batch,补发emptySwapOutPos以启动当前step + std::vector emptySwapOutPos; + SendTensorForSwap(info, swapInPos, emptySwapOutPos); + specialProcessStatus[info.name] = ProcessStatus::AFTER_SWITCH_FIRST_BATCH; + LOG_DEBUG("handle last end batch for current epoch, table:{}, batchId:{}", info.name, info.batchId); + return; + } + + if (info.channelId == EVAL_CHANNEL_ID && IsEvalEndBatch(info.batchId)) { + // 当前step之后eval结束,需要设置处理状态 + // 因为eval、predict最后1个batch之后不会像train那样再往后跑,所以必须放这里补发 + LOG_DEBUG("reach max eval step, send emptySwapOutPos tensor for last step to finish eval, " + "change ProcessStatus to {}, table:{}, batchId:{}", + ProcessStatus2Str(ProcessStatus::AFTER_SWITCH_FIRST_BATCH), info.name, info.batchId); + std::vector emptySwapOutPos; + SendTensorForSwap(info, lastSwapInPosMap[info.name], emptySwapOutPos); + specialProcessStatus[info.name] = ProcessStatus::AFTER_SWITCH_FIRST_BATCH; + } +} + +void HybridMgmt::HandleFirstBatchCaseDDR(const EmbBaseInfo& info, + pair, vector>& swapInKoPair, + pair, vector>& swapOutKoPair) +{ + TimeCost swapProcessTC; + auto& swapInKeys = swapInKoPair.first; + auto& swapInPos = swapInKoPair.second; + auto& swapOutKeys = swapOutKoPair.first; + auto& swapOutPos = swapOutKoPair.second; + + vector emptySwapOutKeys; + LOG_DEBUG("table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", info.name, info.batchId, + info.channelId, swapInKoPair.first.size(), emptySwapOutKeys.size()); + trainTestSwitchInfoStore[info.name] = {swapOutKeys, swapOutPos}; + + LOG_DEBUG("handle first batch case, delay sending swapInPos, table:{}", info.name); + LOG_DEBUG("enqueue HBMSwapKeyQue table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", info.name, + info.batchId, info.channelId, swapInKeys.size(), emptySwapOutKeys.size()); + HBMSwapKeyQue[info.name + SWAP_OUT_STR].Pushv(emptySwapOutKeys); + HBMSwapKeyQue[info.name + SWAP_IN_STR].Pushv(swapInKeys); +} + +void HybridMgmt::HandleFirstBatchCaseL3Storage(const EmbBaseInfo& info, + std::pair, vector>& swapInKoPair, + std::pair, vector>& swapOutKoPair) +{ + // 发现train、save、eval切换,先保存状态,发emptySwapOutKeys以对应上一步的emptySwapOutPos + vector emptySwapOutKeys; + LOG_DEBUG("table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", info.name, info.batchId, + info.channelId, swapInKoPair.first.size(), emptySwapOutKeys.size()); + trainTestSwitchInfoStore[info.name] = {swapOutKoPair.first, 
swapOutKoPair.second}; + + TimeCost ProcessSwapInKeysTC = TimeCost(); + vector L3StorageToDDRKeys; + vector DDRToL3StorageKeys; + cacheManager->ProcessSwapInKeys(info.name, swapInKoPair.first, DDRToL3StorageKeys, L3StorageToDDRKeys); + LOG_DEBUG("ProcessSwapInKeysTC(ms):{} ", ProcessSwapInKeysTC.ElapsedMS()); + + vector emptySwapOutDDRKeys; + vector emptySwapOutDDRAddrOffs; + vector emptySwapOutL3StorageKeys; + vector emptySwapOutL3StorageAddrOff; + + LOG_DEBUG("table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", info.name, info.batchId, + info.channelId, swapInKoPair.first.size(), swapOutKoPair.first.size()); + LOG_DEBUG("table:{}, batchId:{}, channelId:{}, swapOutDDRKeys.size:{}, swapOutDDRAddrOffs.size:{}, " + "swapOutL3StorageKeys.size:{}, swapOutL3StorageAddrOff.size:{}", + info.name, info.batchId, info.channelId, emptySwapOutDDRKeys.size(), emptySwapOutDDRAddrOffs.size(), + emptySwapOutL3StorageKeys.size(), emptySwapOutL3StorageAddrOff.size()); + LOG_DEBUG("table:{}, batchId:{}, channelId:{}, DDRToL3StorageKeys.size:{}, L3StorageToDDRKeys.size:{}", info.name, + info.batchId, info.channelId, DDRToL3StorageKeys.size(), L3StorageToDDRKeys.size()); + + auto DDRToL3StorageKeysForL3S = DDRToL3StorageKeys; + auto L3StorageToDDRKeysForL3S = L3StorageToDDRKeys; + // DDR<->L3Storage + DDRSwapKeyQue[info.name + SWAP_OUT_STR].Pushv(DDRToL3StorageKeys); + DDRSwapKeyQue[info.name + SWAP_IN_STR].Pushv(L3StorageToDDRKeys); + + DDRSwapKeyForL3StorageQue[info.name + SWAP_OUT_STR].Pushv(DDRToL3StorageKeysForL3S); + DDRSwapKeyForL3StorageQue[info.name + SWAP_IN_STR].Pushv(L3StorageToDDRKeysForL3S); + + // HBM<->DDR + HBMSwapKeyQue[info.name + SWAP_OUT_STR].Pushv(emptySwapOutDDRKeys); + HBMSwapKeyQue[info.name + ADDR_STR].Pushv(emptySwapOutDDRAddrOffs); + HBMSwapKeyQue[info.name + SWAP_IN_STR].Pushv(swapInKoPair.first); + + // HBM->L3Storage + HBMSwapKeyForL3StorageQue[info.name + SWAP_OUT_STR].Pushv(emptySwapOutL3StorageKeys); + HBMSwapKeyForL3StorageQue[info.name + ADDR_STR].Pushv(emptySwapOutL3StorageAddrOff); +} + +void HybridMgmt::HandleDataSwapForL3Storage(const EmbBaseInfo& info, vector& swapInKeys, + vector& swapOutKeys) +{ + TimeCost ProcessSwapInKeysTC; + vector L3StorageToDDRKeys; + vector DDRToL3StorageKeys; + cacheManager->ProcessSwapInKeys(info.name, swapInKeys, DDRToL3StorageKeys, L3StorageToDDRKeys); + LOG_DEBUG("ProcessSwapInKeysTC(ms):{} ", ProcessSwapInKeysTC.ElapsedMS()); + + TimeCost ProcessSwapOutKeysTC; + HBMSwapOutInfo hbmSwapInfo; + cacheManager->ProcessSwapOutKeys(info.name, swapOutKeys, hbmSwapInfo); + LOG_DEBUG("ProcessSwapOutKeysTC(ms):{} ", ProcessSwapOutKeysTC.ElapsedMS()); + + LOG_DEBUG("table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", info.name, info.batchId, + info.channelId, swapInKeys.size(), swapOutKeys.size()); + LOG_DEBUG("table:{}, batchId:{}, channelId:{}, swap out, HBM2DDR Keys:{}, HBM2DDR AddrOffs:{}, " + "HBM2L3Storage Keys:{}, HBM2L3Storage AddrOff:{}", + info.name, info.batchId, info.channelId, hbmSwapInfo.swapOutDDRKeys.size(), + hbmSwapInfo.swapOutDDRAddrOffs.size(), hbmSwapInfo.swapOutL3StorageKeys.size(), + hbmSwapInfo.swapOutL3StorageAddrOffs.size()); + LOG_DEBUG("table:{}, batchId:{}, channelId:{}, DDR2L3Storage Keys:{}, L3Storage2DDR Keys:{}", info.name, + info.batchId, info.channelId, DDRToL3StorageKeys.size(), L3StorageToDDRKeys.size()); + + auto DDRToL3StorageKeysForL3S = DDRToL3StorageKeys; + auto L3StorageToDDRKeysForL3S = L3StorageToDDRKeys; + // DDR<->L3Storage + DDRSwapKeyQue[info.name + 
SWAP_OUT_STR].Pushv(DDRToL3StorageKeys); + DDRSwapKeyQue[info.name + SWAP_IN_STR].Pushv(L3StorageToDDRKeys); + + DDRSwapKeyForL3StorageQue[info.name + SWAP_OUT_STR].Pushv(DDRToL3StorageKeysForL3S); + DDRSwapKeyForL3StorageQue[info.name + SWAP_IN_STR].Pushv(L3StorageToDDRKeysForL3S); + + // HBM<->DDR + HBMSwapKeyQue[info.name + SWAP_OUT_STR].Pushv(hbmSwapInfo.swapOutDDRKeys); + HBMSwapKeyQue[info.name + ADDR_STR].Pushv(hbmSwapInfo.swapOutDDRAddrOffs); + HBMSwapKeyQue[info.name + SWAP_IN_STR].Pushv(swapInKeys); + + // HBM->L3Storage + HBMSwapKeyForL3StorageQue[info.name + SWAP_OUT_STR].Pushv(hbmSwapInfo.swapOutL3StorageKeys); + HBMSwapKeyForL3StorageQue[info.name + ADDR_STR].Pushv(hbmSwapInfo.swapOutL3StorageAddrOffs); +} + +bool HybridMgmt::BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dEmb) +{ + std::vector swapInAddrs = HBMSwapAddrsQue[info.name + SWAP_IN_STR].WaitAndPop(); + if (!isRunning) { + return false; + } + h2dEmb.emplace_back( + Tensor(tensorflow::DT_FLOAT, {int(swapInAddrs.size()), static_cast(info.extEmbeddingSize)})); + auto& tmpTensor = h2dEmb.back(); + float* h2dEmbAddr = tmpTensor.flat().data(); + TimeCost embeddingLookupTC = TimeCost(); + + uint64_t memSize = info.extEmbeddingSize * sizeof(float); +#pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) shared(swapInAddrs, h2dEmbAddr, info, memSize) + for (uint64_t i = 0; i < swapInAddrs.size(); i++) { + auto rc = memcpy_s(h2dEmbAddr + i * info.extEmbeddingSize, memSize, swapInAddrs[i], memSize); + if (rc != 0) { + throw runtime_error("memcpy_s failed, error code:" + to_string(rc)); + } + } + LOG_DEBUG("table:{}, thread:{}, batchId:{}, send h2dEmb, emb size:{}, emb samples:{}, embeddingLookupTC(ms):{}", + info.name.c_str(), info.threadIdx, info.batchId, swapInAddrs.size(), + FloatPtrToLimitStr(h2dEmbAddr, swapInAddrs.size() * info.extEmbeddingSize), + embeddingLookupTC.ElapsedMS()); + return true; +} + +vector HybridMgmt::GetUniqueKeys(const EmbBaseInfo& info, bool& remainBatchOut) +{ + bool isEos = false; + auto uniqueKeys = KEY_PROCESS_INSTANCE->GetUniqueKeys(info, isEos, lookUpSwapInAddrsPushId); + if (isEos) { + HandleEosCase(info, remainBatchOut); + return uniqueKeys; + } + if (uniqueKeys.empty()) { + remainBatchOut = false; + LOG_WARN("table:{}, channelId:{} batchId:{}, UniqueKeys result is empty", info.name, info.channelId, + info.batchId); + return uniqueKeys; + } + + if (info.channelId == TRAIN_CHANNEL_ID) { + TimeCost KeyMaintainTC; + trainKeysSet[info.name].insert(uniqueKeys.begin(), uniqueKeys.end()); + LOG_DEBUG("table:{}, batchId:{}, KeyMaintainTC(ms):{}", info.name, info.batchId, KeyMaintainTC.ElapsedMS()); + } else { + for (auto& key : uniqueKeys) { + if (trainKeysSet[info.name].find(key) == trainKeysSet[info.name].end()) { + key = INVALID_KEY_VALUE; + LOG_TRACE("find key not train before, set as invalid key"); + } + } + } + + LOG_DEBUG("table:{}, channelId:{} batchId:{}, GetUniqueKeys end", info.name, info.channelId, info.batchId); + return uniqueKeys; +} + +vector HybridMgmt::GetRestoreVecSec(const EmbBaseInfo& info, bool& remainBatchOut) +{ + auto restoreVecSec = KEY_PROCESS_INSTANCE->GetRestoreVecSec(info); + if (restoreVecSec.empty()) { + remainBatchOut = false; + LOG_WARN("table:{}, channelId:{} batchId:{}, restoreVecSec result is empty", info.name, info.channelId, + info.batchId); + return restoreVecSec; + } + LOG_DEBUG("table:{}, channelId:{} batchId:{}, GetRestoreVecSec end", info.name, info.channelId, info.batchId); + return restoreVecSec; +} + +void 
HybridMgmt::SendAll2AllVec(const EmbBaseInfo& info, bool& remainBatchOut) +{ + if (!mgmtRankInfo.useStatic) { + bool isEos = false; // useless, adapt to HBM mode + TimeCost getAll2AllTC; + unique_ptr> all2all = KEY_PROCESS_INSTANCE->GetInfoVec(info, ProcessedInfo::ALL2ALL, isEos); + LOG_DEBUG("table:{}, channelId:{}, batchId:{}, GetInfoVec all2all end, GetAll2AllTC(ms):{}", info.name, + info.channelId, info.batchId, getAll2AllTC.ElapsedMS()); + if (all2all == nullptr) { + remainBatchOut = false; + LOG_WARN("Information vector is nullptr!"); + return; + } + TimeCost sendAll2AllTC; + hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, info.channelId, info.name); + LOG_DEBUG("table:{}, channelId:{}, batchId:{}, send all2all end, sendAll2AllTC(ms):{}", info.name, + info.channelId, info.batchId, sendAll2AllTC.ElapsedMS()); + } +} + +void HybridMgmt::SendRestoreVec(const EmbBaseInfo& info, bool& remainBatchOut) +{ + bool isEos = false; // useless, adapt to HBM mode + TimeCost getRestoreTC; + unique_ptr> infoVecs = KEY_PROCESS_INSTANCE->GetInfoVec(info, ProcessedInfo::RESTORE, isEos); + if (infoVecs == nullptr) { + remainBatchOut = false; + if (isRunning) { + LOG_ERROR("Information vector is nullptr!"); + } + return; + } + LOG_DEBUG("table:{}, channelId:{}, batchId:{}, get restore end, getRestoreTC(ms):{}", info.name, info.channelId, + info.batchId, getRestoreTC.ElapsedMS()); + + TimeCost sendRestoreSyncTC; + hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, info.channelId, info.name); + LOG_DEBUG("table:{}, channelId:{}, batchId:{}, send restore end, sendRestoreSyncTC(ms):{}", info.name, + info.channelId, info.batchId, sendRestoreSyncTC.ElapsedMS()); +} + +void HybridMgmt::SendLookupOffsets(const EmbBaseInfo& info, vector& uniqueKeys, + vector& restoreVecSec) +{ + // uniqueKeys already transfer to offset in GetSwapPairsAndKey2Offset + // graph will filter out invalid offset(-1). see function _set_specific_value_for_non_valid_key + TimeCost sendLookupOffsetsTC; + std::vector lookupOffsets; + for (const auto& index : restoreVecSec) { + if (index == INVALID_INDEX_VALUE) { + lookupOffsets.emplace_back(static_cast(INVALID_KEY_VALUE)); + continue; + } + lookupOffsets.emplace_back(uniqueKeys[index]); + } + hdTransfer->Send(TransferChannel::LOOKUP, {Vec2TensorI32(lookupOffsets)}, info.channelId, info.name); + LOG_DEBUG("table:{}, channelId:{}, batchId:{}, send lookupOffset, sendLookupOffsetsTC(ms):{}", info.name, + info.channelId, info.batchId, sendLookupOffsetsTC.ElapsedMS()); +} + +void HybridMgmt::SendGlobalUniqueVec(const EmbBaseInfo& info, vector& uniqueKeys, + vector& restoreVecSec) +{ + if (!(info.channelId == TRAIN_CHANNEL_ID && mgmtRankInfo.useSumSameIdGradients)) { + return; + } + TimeCost sendUniqueKeysSyncTC; + hdTransfer->Send(TransferChannel::UNIQKEYS, + {mgmtRankInfo.useDynamicExpansion ? 
Vec2TensorI64(uniqueKeys) : Vec2TensorI32(uniqueKeys)}, + info.channelId, info.name); + LOG_DEBUG("table:{}, channelId:{}, batchId:{}, sendUniqueKeysSyncTC(ms):{}", info.name, info.channelId, + info.batchId, sendUniqueKeysSyncTC.ElapsedMS()); + + TimeCost sendRestoreVecSecSyncTC; + hdTransfer->Send(TransferChannel::RESTORE_SECOND, {Vec2TensorI32(restoreVecSec)}, info.channelId, info.name); + LOG_DEBUG("table:{}, channelId:{}, batchId:{}, sendRestoreVecSecSyncTC(ms):{}", info.name, info.channelId, + info.batchId, sendRestoreVecSecSyncTC.ElapsedMS()); +} + +bool HybridMgmt::HandleSpecialProcessStatusDDR(const EmbBaseInfo& info, TimeCost& getAndSendTensorsTC, + pair, vector>& swapInKoPair, + pair, vector>& swapOutKoPair) +{ + TimeCost swapProcessTC; + auto& swapInPos = swapInKoPair.second; + auto& swapOutKeys = swapOutKoPair.first; + auto& swapOutPos = swapOutKoPair.second; + + if (specialProcessStatus[info.name] == ProcessStatus::AFTER_SWITCH_FIRST_BATCH) { + // 发现train、save、eval切换,先保存状态,发emptySwapOutKeys以对应上一步的emptySwapOutPos + HandleFirstBatchCaseDDR(info, swapInKoPair, swapOutKoPair); + LOG_DEBUG("handle channel switch case:afterSwitchFirstBatch, table:{}, channelId:{}, batchId:{}", info.name, + info.channelId, info.batchId); + + if (mgmtRankInfo.ctrlSteps[info.channelId] == 1) { + vector emptySwapOutPos; + SendTensorForSwap(info, swapInPos, emptySwapOutPos); + LOG_DEBUG("ProcessEmbInfoDDR special case, user only run one step, table:{}, channelId:{}, batchId:{}", + info.name, info.channelId, info.batchId); + return true; + } + + specialProcessStatus[info.name] = ProcessStatus::AFTER_SWITCH_SECOND_BATCH; + LOG_DEBUG("ProcessEmbInfoDDR end, table:{}, batchId:{}, swapProcessTC(ms):{}, getAndSendTensorsTC(ms):{}", + info.name, info.batchId, swapProcessTC.ElapsedMS(), getAndSendTensorsTC.ElapsedMS()); + return true; + } + if (specialProcessStatus[info.name] == ProcessStatus::AFTER_SWITCH_SECOND_BATCH) { + // 将上一步暂存的状态合并至当前step一起处理 + auto tempStore = trainTestSwitchInfoStore[info.name]; + swapOutKeys.insert(swapOutKeys.end(), tempStore[0].begin(), tempStore[0].end()); + swapOutPos.insert(swapOutPos.end(), tempStore[1].begin(), tempStore[1].end()); + specialProcessStatus[info.name] = ProcessStatus::NORMAL; + LOG_DEBUG("handle channel switch case:afterSwitchSecondBatch, table:{}, channelId:{}, batchId:{}", info.name, + info.channelId, info.batchId); + } + return false; +} + +bool HybridMgmt::HandleSpecialProcessStatusL3Storage(const EmbBaseInfo& info, TimeCost& getAndSendTensorsTC, + pair, vector>& swapInKoPair, + pair, vector>& swapOutKoPair) +{ + TimeCost swapProcessTC; + auto& swapInPos = swapInKoPair.second; + auto& swapOutKeys = swapOutKoPair.first; + auto& swapOutPos = swapOutKoPair.second; + + if (specialProcessStatus[info.name] == ProcessStatus::AFTER_SWITCH_FIRST_BATCH) { + // 发现train、save、eval切换,先保存状态,发emptySwapOutKeys以对应上一步的emptySwapOutPos + HandleFirstBatchCaseL3Storage(info, swapInKoPair, swapOutKoPair); + LOG_DEBUG("handle channel switch case:afterSwitchFirstBatch, table:{}, channelId:{}, batchId:{}", info.name, + info.channelId, info.batchId); + + if (mgmtRankInfo.ctrlSteps[info.channelId] == 1) { + vector emptySwapOutPos; + SendTensorForSwap(info, swapInPos, emptySwapOutPos); + LOG_DEBUG("ProcessEmbInfoL3Storage special case, user only run one step, " + "table:{}, channelId:{}, batchId:{}", + info.name, info.channelId, info.batchId); + } + + specialProcessStatus[info.name] = ProcessStatus::AFTER_SWITCH_SECOND_BATCH; + LOG_DEBUG("ProcessEmbInfoL3Storage end, table:{}, batchId:{}, 
swapProcessTC(ms):{}, getAndSendTensorsTC(ms):{}", + info.name, info.batchId, swapProcessTC.ElapsedMS(), getAndSendTensorsTC.ElapsedMS()); + return true; + } + if (specialProcessStatus[info.name] == ProcessStatus::AFTER_SWITCH_SECOND_BATCH) { + // 将上一步暂存的状态合并至当前step一起处理 + auto tempStore = trainTestSwitchInfoStore[info.name]; + swapOutKeys.insert(swapOutKeys.end(), tempStore[0].begin(), tempStore[0].end()); + swapOutPos.insert(swapOutPos.end(), tempStore[1].begin(), tempStore[1].end()); + specialProcessStatus[info.name] = ProcessStatus::NORMAL; + LOG_DEBUG("handle channel switch case:afterSwitchSecondBatch, table:{}, channelId:{}, batchId:{}", info.name, + info.channelId, info.batchId); + } + return false; +} + +void HybridMgmt::CheckLookupAddrSuccessDDR() +{ + if (!lookupAddrSuccess) { + // lookup失败,从future捞出异常 + for (auto& t : lookUpSwapInAddrsThreads) { + t.get(); + } + for (auto& t : lookUpSwapOutAddrsThreads) { + t.get(); + } + } +} + +void HybridMgmt::GetSwapPairsAndKey2Offset(const EmbBaseInfo& info, vector& uniqueKeys, + pair, vector>& swapInKoPair, + pair, vector>& swapOutKoPair) +{ + TimeCost GetSwapPairsAndKey2OffsetTC; + int swapInCode = embCache->GetSwapPairsAndKey2Offset(info.name, uniqueKeys, swapInKoPair, swapOutKoPair); + if (swapInCode != H_OK) { + string errMsg = + StringFormat("table:%s, GetSwapPairsAndKey2Offset failed! error code:%d", info.name.c_str(), swapInCode); + throw runtime_error(errMsg); + } + LOG_DEBUG("table:{}, channel:{}, batchId:{}, GetSwapPairsAndKey2OffsetTC(ms):{}", info.name, info.channelId, + info.batchId, GetSwapPairsAndKey2OffsetTC.ElapsedMS()); + + LOG_DEBUG("table:{}, channel:{}, batchId:{}, swapIn keys:{}, swapIn pos:{}, swapOut keys:{}, swapOut pos:{}", + info.name, info.channelId, info.batchId, VectorToString(swapInKoPair.first), + VectorToString(swapInKoPair.second), VectorToString(swapOutKoPair.first), + VectorToString(swapOutKoPair.second)); +} + +void HybridMgmt::EnqueueSwapInfo(const EmbBaseInfo& info, pair, vector>& swapInKoPair, + pair, vector>& swapOutKoPair) +{ + auto& swapInKeys = swapInKoPair.first; + auto& swapOutKeys = swapOutKoPair.first; + + LOG_DEBUG("enqueue HBMSwapKeyQue table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", info.name, + info.batchId, info.channelId, swapInKeys.size(), swapOutKeys.size()); + HBMSwapKeyQue[info.name + SWAP_OUT_STR].Pushv(swapOutKeys); + HBMSwapKeyQue[info.name + SWAP_IN_STR].Pushv(swapInKeys); + + CheckLookupAddrSuccessDDR(); +} + +bool HybridMgmt::IsTrainAndEvalCase() +{ + bool isChannelSwitchCase = false; + for (auto& i : mgmtEmbInfo) { + if (specialProcessStatus[i.name] == ProcessStatus::AFTER_SWITCH_FIRST_BATCH) { + isChannelSwitchCase = true; + break; + } + } + return alreadyTrainOnce && isChannelSwitchCase; +} + +void HybridMgmt::BackUpTrainStatus() +{ + int channelID = TRAIN_CHANNEL_ID; + int& theTrainBatchId = hybridMgmtBlock->hybridBatchId[channelID]; + if (theTrainBatchId == 0) { + return; + } + + LOG_INFO("On Estimator train and eval mode, start to backup train status, " + "current train batchId: {} .", theTrainBatchId); + // When in the train and eval mode of estimator, backup training states before loading. 
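+    // RecoverTrainStatus() restores this backup after the load completes, so a load issued for eval cannot
+    // clobber the in-flight training state.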
+ EmbeddingMgmt::Instance()->BackUpTrainStatusBeforeLoad(); + + if (isL3StorageEnabled) { + cacheManager->BackUpTrainStatus(); + } + isBackUpTrainStatus = true; +} + +void HybridMgmt::RecoverTrainStatus() +{ + if (isBackUpTrainStatus) { + EmbeddingMgmt::Instance()->RecoverTrainStatus(); + } + + if (isL3StorageEnabled) { + cacheManager->RecoverTrainStatus(); + } + isBackUpTrainStatus = false; +} \ No newline at end of file diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 0251eb91f8ad833f9efea0b02c95511a9421c9dc..5f94c96dafd99f411aad51e981dda2a0045d014a 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -17,143 +17,305 @@ See the License for the specific language governing permissions and #define MX_REC_EMB_MGMT_H #include -#include #include +#include +#include #include "absl/container/flat_hash_map.h" - -#include "utils/common.h" -#include "utils/config.h" - -#include "host_emb/host_emb.h" -#include "emb_hashmap/emb_hashmap.h" +#include "emb_table/embedding_table.h" #include "hd_transfer/hd_transfer.h" -#include "ssd_cache/cache_manager.h" #include "hybrid_mgmt_block.h" -#include "emb_table/embedding_table.h" +#include "l3_storage/cache_manager.h" +#include "ock_ctr_common/include/embedding_cache.h" +#include "ock_ctr_common/include/error_code.h" +#include "ock_ctr_common/include/factory.h" +#include "utils/common.h" +#include "utils/config.h" +#include "utils/singleton.h" +#include "utils/task_queue.h" +#include "utils/time_cost.h" namespace MxRec { - using namespace std; - using namespace tensorflow; - - enum class TaskType { - HBM, - DDR - }; - - class HybridMgmt { - public: - HybridMgmt() = default; - - ~HybridMgmt() - { - if (isRunning) { - Destroy(); - } +using namespace std; +using namespace tensorflow; +using namespace Common; + +enum class TaskType { + HBM, + DDR +}; + +enum class ProcessStatus { + NORMAL, + AFTER_SWITCH_FIRST_BATCH, + AFTER_SWITCH_SECOND_BATCH +}; + +inline string ProcessStatus2Str(ProcessStatus s) +{ + switch (s) { + case ProcessStatus::NORMAL: + return "normal"; + case ProcessStatus::AFTER_SWITCH_FIRST_BATCH: + return "afterSwitchFirstBatch"; + case ProcessStatus::AFTER_SWITCH_SECOND_BATCH: + return "afterSwitchSecondBatch"; + default: + throw std::invalid_argument("Invalid ProcessStatus"); + } +}; + +struct EmbTaskInfo { + int batchId; + int threadIdx; + int cvNotifyIndex; + int extEmbeddingSize; + string name; +}; + +class HybridMgmt { +public: + HybridMgmt() = default; + + ~HybridMgmt() + { + if (isRunning) { + Destroy(); } + } + + HybridMgmt(const HybridMgmt&) = delete; + + HybridMgmt& operator=(const HybridMgmt&) = delete; + + bool Initialize(RankInfo rankInfo, const vector& embInfos, int seed, + const vector& thresholdValues, bool ifLoad); + + void Save(const string& savePath); + + bool Load(const string& loadPath, vector warmStartTables); + + OffsetT SendHostMap(const string tableName); + + OffsetT SendLoadMap(const string tableName); + + void ReceiveHostMap(AllKeyOffsetMapT receiveKeyOffsetMap); + + void Start(); + + void StartThreadForHBM(); + + void StartThreadForDDR(); + + void Destroy(); + + bool ParseKeys(int channelId, int& batchId, TaskType type); + + bool Evict(); + + void NotifyBySessionRun(int channelID) const; + + void CountStepBySessionRun(int channelID, int steps) const; + + int64_t GetTableSize(const string& embName) const; + + int64_t GetTableCapacity(const string& embName) const; + + void SetOptimizerInfo(const string& embName, OptimizerInfo optimInfo) 
const; + + void FetchDeviceEmb(); + + void ProcessEmbInfoHBM(const EmbBaseInfo& info, bool& remainBatchOut, bool isGrad); + + void ProcessEmbInfoDDR(const EmbBaseInfo& info, bool& remainBatchOut); + + void ProcessEmbInfoL3Storage(const EmbBaseInfo& info, bool& remainBatchOut); + + void BackUpTrainStatus(); + + void RecoverTrainStatus(); + + GTEST_PRIVATE : bool mutexDestroy{false}; + std::mutex lookUpAndSendBatchIdMtx; + std::mutex receiveAndUpdateBatchIdMtx; + std::map lookUpAndSendTableBatchMap; + std::map receiveAndUpdateTableBatchMap; + + std::map> lastUpdateFinishMutexMap; + std::map> cvLastUpdateFinishMap; + std::map lastUpdateFinishStepMap; + std::map> lastLookUpFinishMutexMap; + std::map> cvLastLookUpFinishMap; + std::map lastLookUpFinishStepMap; + std::map> lastSendFinishMutexMap; + std::map> cvLastSendFinishMap; + std::map lastSendFinishStepMap; + std::map> lastRecvFinishMutexMap; + std::map> cvLastRecvFinishMap; + std::map lastRecvFinishStepMap; + + std::vector EmbeddingLookUpAndSendThreadPool; + std::vector EmbeddingReceiveAndUpdateThreadPool; + std::vector> lookUpSwapOutAddrsThreads; + std::vector> lookUpSwapInAddrsThreads; + + std::map>> HBMSwapKeyQue; + std::map>> HBMSwapKeyForL3StorageQue; + std::map>> DDRSwapKeyQue; + std::map>> DDRSwapKeyForL3StorageQue; + std::map>> HBMSwapAddrsQue; + std::map>> DDRSwapAddrsQue; + + std::mutex evictMut; + + std::map> trainKeysSet; + const string SWAP_IN_STR = "SwapIn"; + const string SWAP_OUT_STR = "SwapOut"; + + const string ADDR_STR = "Addr"; + ock::ctr::EmbCacheManagerPtr embCache = nullptr; + std::map> lastSwapInPosMap{}; + std::map>> trainTestSwitchInfoStore{}; + std::atomic lookupAddrSuccess{true}; + + std::mutex saveMutex; + std::condition_variable cvCheckSave; + + void SetFeatureTypeForLoad(vector& loadFeatures); + + void EvictKeys(const string& embName, const vector& keys); + + void InitRankInfo(RankInfo& rankInfo, const vector& embInfos) const; + + void EvictL3StorageKeys(const string& embName, const vector& keys) const; + + void LookUpAndRemoveAddrs(const EmbTaskInfo& info); // L3Storage, synchronous + + void LookUpSwapAddrs(const std::string& embName); // DDR, asynchronous + + void EmbeddingTask(); + + void MultiThreadEmbHDTransWrap(); + + void EmbeddingLookUpAndSendDDR(int batchId, int index, const EmbInfo& embInfo); + + void EmbeddingReceiveAndUpdateDDR(int batchId, int index, const EmbInfo& embInfo); + + void EmbeddingLookUpAndSendL3Storage(int batchId, int index, const EmbInfo& embInfo); + + void EmbeddingReceiveAndUpdateL3Storage(int batchId, int index, const EmbInfo& embInfo); + + void SendTensorForSwap(const EmbBaseInfo& info, const vector& swapInPosUint, + const vector& swapOutPosUint); - HybridMgmt(const HybridMgmt&) = delete; +private: + HybridMgmtBlock* hybridMgmtBlock; + vector mgmtEmbInfo; + RankInfo mgmtRankInfo; + CacheManager* cacheManager; + vector> procThreads{}; + map> evictKeyMap{}; + HDTransfer* hdTransfer; + OffsetMapT offsetMapToSend; + OffsetMapT loadOffsetToSend; + bool isL3StorageEnabled{false}; + bool isRunning; + bool isLoad{false}; + bool isInitialized{false}; + bool alreadyTrainOnce = false; // 用于判断是否为predict模式 + bool isBackUpTrainStatus = false; // whether the train state has been backed up + map lookUpSwapInAddrsPushId; // 用于处理eos场景,当消费者追上生产者且长时间无上游数据,会触发eos + map specialProcessStatus; - HybridMgmt& operator=(const HybridMgmt&) = delete; + void TrainTask(TaskType type); - bool Initialize(RankInfo rankInfo, const vector& embInfos, int seed, - const vector& thresholdValues, bool ifLoad); + void 
EvalTask(TaskType type); - bool Save(const string savePath); + void SendUniqKeysAndRestoreVecHBM(const EmbBaseInfo& info, const unique_ptr>& infoVecs, + bool isGrad) const; - bool Load(const string& loadPath); + void HandleEndBatchCase(const EmbBaseInfo& info, vector& swapInPos); - OffsetT SendHostMap(const string tableName); + bool IsTrainEndBatch(int batchId) const; - OffsetT SendLoadMap(const string tableName); + bool IsEvalEndBatch(int batchId) const; - void ReceiveHostMap(AllKeyOffsetMapT receiveKeyOffsetMap); + void InitEmbeddingCache(const vector& embInfos); - void Start(); + void InitDataPipelineForDDR(const string& embName); - void StartThreadForHBM(); + void InitDataPipelineForL3Storage(const string& embName, int extEmbeddingSize); - void StartThreadForDDR(); + void JoinEmbeddingCacheThread(); - void Destroy(); + void HandleReachMaxStepCase(const EmbBaseInfo& info, bool& remainBatchOut); - bool ParseKeys(int channelId, int& batchId); + void HandleEosCase(const EmbBaseInfo& info, bool& remainBatchOut); - bool ParseKeysHBM(int channelId, int& batchId); + void HandleEosCaseHBM(const string& embName, int batchId, int channelId, bool& remainBatchOut); - bool ProcessEmbInfo(const std::string& embName, int batchId, int channelId, bool& remainBatchOut); + bool EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs); - void EmbHDTrans(const int channelId, const int batchId); + void EmbeddingUpdateDDR(const EmbTaskInfo& info, const float* embPtr, vector& swapOutAddrs); - bool Evict(); + bool EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2dEmb); - void NotifyBySessionRun(int channelID) const; + void EmbeddingSendDDR(const EmbTaskInfo& info, vector& h2dEmb); - void CountStepBySessionRun(int channelID, int steps) const; + bool EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, int64_t& dims0); - int64_t GetTableSize(const string& embName) const; + void EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float* embPtr, vector& swapOutAddrs, int64_t& dims0); - int64_t GetTableCapacity(const string& embName) const; + bool EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vector& h2dEmb); - void SetOptimizerInfo(const string& embName, OptimizerInfo optimInfo) const; + void EmbeddingSendL3Storage(const EmbTaskInfo& info, vector& h2dEmb); - GTEST_PRIVATE: + void CreateEmbeddingLookUpAndSendThread(int index, const EmbInfo& embInfo); - void SetFeatureTypeForLoad(vector& loadFeatures); + void CreateEmbeddingReceiveAndUpdateThread(int index, const EmbInfo& embInfo); - bool IsLoadDataMatches(const EmbMemT& loadHostEmbs, const EmbInfo& setupHostEmbs, size_t& embTableCount) const; + void HandleFirstBatchCaseDDR(const EmbBaseInfo& info, std::pair, vector>& swapInKoPair, + std::pair, vector>& swapOutKoPair); - void EvictKeys(const string& embName, const vector& keys); + void HandleFirstBatchCaseL3Storage(const EmbBaseInfo& info, + std::pair, vector>& swapInKoPair, + std::pair, vector>& swapOutKoPair); - void InitRankInfo(RankInfo& rankInfo, const vector& embInfos) const; + void HandleDataSwapForL3Storage(const EmbBaseInfo& info, vector& swapInKeys, + vector& swapOutKeys); - void EvictSSDKeys(const string& embName, const vector& keys) const; + bool BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dEmb); - void PrepareDDRData(std::shared_ptr table, - const vector &keys, int channelId, int batchId) const; + vector GetUniqueKeys(const EmbBaseInfo& info, bool& remainBatchOut); - int GetStepFromPath(const string& loadPath) const; + vector 
GetRestoreVecSec(const EmbBaseInfo& info, bool& remainBatchOut); - static void AddCacheManagerTraceLog(CkptData& saveData); + void SendAll2AllVec(const EmbBaseInfo& info, bool& remainBatchOut); - void RestoreFreq4Save(CkptData& saveData) const; - private: - int currentBatchId; - int trainBatchId = 0; // 0-199, 200- - int getInfoBatchId; // 0-199, 200- - int sendBatchId; - HybridMgmtBlock* hybridMgmtBlock; - vector mgmtEmbInfo; - RankInfo mgmtRankInfo; - CacheManager* cacheManager; - HostEmb* hostEmbs {}; - unique_ptr hostHashMaps {}; - vector> procThreads {}; - map> evictKeyMap {}; - HDTransfer *hdTransfer; - OffsetMapT offsetMapToSend; - OffsetMapT loadOffsetToSend; - bool isSSDEnabled { false }; - bool isRunning; - bool isLoad { false }; - bool isInitialized { false }; + void SendRestoreVec(const EmbBaseInfo& info, bool& remainBatchOut); - void TrainTask(TaskType type); + void SendLookupOffsets(const EmbBaseInfo& info, vector& uniqueKeys, vector& restoreVecSec); - void EvalTask(TaskType type); + void SendGlobalUniqueVec(const EmbBaseInfo& info, vector& uniqueKeys, vector& restoreVecSec); - bool EndBatch(int batchId, int channelId) const; + bool HandleSpecialProcessStatusDDR(const EmbBaseInfo& info, TimeCost& getAndSendTensorsTC, + std::pair, vector>& swapInKoPair, + std::pair, vector>& swapOutKoPair); - void EmbHDTransWrap(int channelId, const int& batchId, int start); + bool HandleSpecialProcessStatusL3Storage(const EmbBaseInfo& info, TimeCost& getAndSendTensorsTC, + std::pair, vector>& swapInKoPair, + std::pair, vector>& swapOutKoPair); - bool LoadMatchesDDRSetup(const CkptData& loadData); + void CheckLookupAddrSuccessDDR(); - void HandlePrepareDDRDataRet(TransferRet prepareSSDRet) const; + void GetSwapPairsAndKey2Offset(const EmbBaseInfo& info, vector& uniqueKeys, + std::pair, vector>& swapInKoPair, + std::pair, vector>& swapOutKoPair); - void SendUniqKeysAndRestoreVecHBM(int channelId, int& batchId, const EmbInfo &embInfo, - const unique_ptr> &infoVecs) const; + void EnqueueSwapInfo(const EmbBaseInfo& info, std::pair, vector>& swapInKoPair, + std::pair, vector>& swapOutKoPair); - void SendUniqKeysAndRestoreVecDDR(const string &embName, int &batchId, int &channelId, DDRParam &ddrParam); - }; -} -#endif // MX_REC_EMB_MGMT_H + bool IsTrainAndEvalCase(); +}; +} // namespace MxRec +#endif // MX_REC_EMB_MGMT_H diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp index ad10bac4468200747a464fb482a1982178b88297..04433469fff5da40870de274bed445b891f62759 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp @@ -40,6 +40,7 @@ void HybridMgmtBlock::CheckAndSetBlock(int channelId) LOG_DEBUG(HYBRID_BLOCKING + "blocking by save saveInterval {} pythonBatchId {} hybridBatchId {}", saveInterval, pythonBatchId[channelId], hybridBatchId[channelId]); isBlock[TRAIN_CHANNEL_ID] = true; + finishSave = false; } if (stepsInterval[channelId] == -1) { return; @@ -74,7 +75,8 @@ bool HybridMgmtBlock::WaitValid(int channelId) { // 等待hybrid处理完成 int reTryNumber = 100; - LOG_INFO(HYBRID_BLOCKING + "check step invalid, wait {} {}", channelId, hybridBatchId[channelId]); + LOG_INFO(HYBRID_BLOCKING + "validate step and wait, channel:{}, pythonBatchId:{}, hybridBatchId:{}", + channelId, pythonBatchId[channelId], hybridBatchId[channelId]); // 等待hybrid处理完成后再一次唤醒 while (pythonBatchId[lastRunChannelId] != hybridBatchId[lastRunChannelId] and isRunning) { std::this_thread::sleep_for(std::chrono::milliseconds(10ms)); @@ -85,6 +87,8 
@@ bool HybridMgmtBlock::WaitValid(int channelId) } if (pythonBatchId[channelId] == hybridBatchId[channelId]) { + LOG_INFO(HYBRID_BLOCKING + "step caught up, channel:{}, pythonBatchId:{}, hybridBatchId:{}", + channelId, pythonBatchId[channelId], hybridBatchId[channelId]); return true; } else { // 如果等待python侧处理较长时间后hybrid依旧无法追赶上python则异常 @@ -109,18 +113,21 @@ void HybridMgmtBlock::CheckValid(int channelId) } // 当python侧第一次调用时,此时跳过参数检查 if (lastRunChannelId == -1) { - LOG_DEBUG(HYBRID_BLOCKING + "The data channel was called for the first time, and the parameters were " - "checked to be normal channelId {} hybridBatchId {}", channelId, hybridBatchId[channelId]); + LOG_DEBUG(HYBRID_BLOCKING + + "The data channel was called for the first time, and the parameters were " + "checked to be normal channelId {} hybridBatchId {}.", + channelId, hybridBatchId[channelId]); lastRunChannelId = channelId; return; } + // 在通道切换时,hybrid预处理的batch与python的一致。 if (pythonBatchId[lastRunChannelId] == hybridBatchId[lastRunChannelId]) { LOG_DEBUG(HYBRID_BLOCKING + - "HybridMgmt is switching data channels and checking for normal parameters. he number of steps " - "in the previous round is lastRunChannelId {} pythonBatchId {} hybridBatchId {}", - lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]); + "HybridMgmt is switching data channels and checking for normal parameters. The number of steps " + "in the previous round is lastRunChannelId {} pythonBatchId {} hybridBatchId {}.", + lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]); } else if (pythonBatchId[lastRunChannelId] < hybridBatchId[lastRunChannelId]) { // 在通道切换时,上一个通道处理的数据超出了python侧的调用 if (rankInfo.isDDR and !WaitValid(lastRunChannelId)) { @@ -129,10 +136,10 @@ void HybridMgmtBlock::CheckValid(int channelId) } else { // 在通道切换时,hybrid处理的数据还没有赶上python侧,此时需要等待hybrid处理完成 LOG_INFO(HYBRID_BLOCKING + - "When switching data channels, it was found that HybridMgmt processed less data than the " - "Python side.In this case, after reading the dataset, the Python side called it again, but it was " - "interrupted midway,which did not affect the subsequent calls lastRunChannelId {} hybridBatchId {}", - lastRunChannelId, hybridBatchId[lastRunChannelId]); + "When switching data channels, it was found that HybridMgmt processed less data than the " + "Python side. In this case, after reading the dataset, the Python side called it again, but it was " + "interrupted midway, which did not affect the subsequent calls. lastRunChannelId {} hybridBatchId {}", + lastRunChannelId, hybridBatchId[lastRunChannelId]); } lastRunChannelId = channelId; } @@ -143,7 +150,7 @@ void HybridMgmtBlock::DoBlock(int channelId) { // 通道没有切换,不用处理 LOG_DEBUG(HYBRID_BLOCKING + "HybridMgmt starts blocking channelId {} hybridBatchId {}", - channelId, hybridBatchId[channelId]); + channelId, hybridBatchId[channelId]); while (isBlock[channelId]) { std::this_thread::sleep_for(SLEEP_MS); @@ -152,20 +159,28 @@ void HybridMgmtBlock::DoBlock(int channelId) } } LOG_DEBUG(HYBRID_BLOCKING + "HybridMgmt is starting to wake up channelId {} hybridBatchId {}", - channelId, hybridBatchId[channelId]); + channelId, hybridBatchId[channelId]); } /// 重置所有的步数,主要用于图重构的情况,readembedkey算子重建 /// \param channelId channelId train 0 eval 1 void HybridMgmtBlock::ResetAll(int channelId) { - LOG_DEBUG(HYBRID_BLOCKING + "Hybridmgmt is resetting data channelId {} hybridBatchId {}", - channelId, hybridBatchId[channelId]); + LOG_DEBUG(HYBRID_BLOCKING + "start reset block status," + " 
channelId:{}, pythonBatchId:{}, readEmbedBatchId:{}, hybridBatchId:{}", + channelId, pythonBatchId[channelId], readEmbedBatchId[channelId], hybridBatchId[channelId]); + if (channelId == EVAL_CHANNEL_ID) { + evalBatchIdTotal += readEmbedBatchId[channelId]; + } + readEmbedBatchId[channelId] = 0; pythonBatchId[channelId] = 0; hybridBatchId[channelId] = 0; isBlock[channelId] = false; + LOG_DEBUG(HYBRID_BLOCKING + "after reset block status," + " channelId:{}, pythonBatchId:{}, readEmbedBatchId:{}, hybridBatchId:{}", + channelId, pythonBatchId[channelId], readEmbedBatchId[channelId], hybridBatchId[channelId]); LOG_DEBUG("Start to reset isNeedSendEos"); Singleton::GetInstance()->SetEos(0, channelId); @@ -178,24 +193,24 @@ int HybridMgmtBlock::CheckSaveEmbMapValid() // 检查数据通道此时的HashMap是否被提前处理了 if (pythonBatchId[lastRunChannelId] >= hybridBatchId[lastRunChannelId]) { LOG_DEBUG(HYBRID_BLOCKING + - "HybridMgmt is checking the step and checking that the parameters are normal. " - "The number of steps in the previous round is " - "lastRunChannelId {} pythonBatchId {} hybridBatchId {}", - lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]); + "HybridMgmt is checking the step and checking that the parameters are normal. " + "The number of steps in the previous round is " + "lastRunChannelId {} pythonBatchId {} hybridBatchId {}", + lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]); return 0; } else if (pythonBatchId[lastRunChannelId] + 1 == hybridBatchId[lastRunChannelId]) { // 在通道切换时,上一个通道处理的数据超出了python侧的调用 LOG_DEBUG(HYBRID_BLOCKING + - "HybridMgmt is checking the step, and the parameters have been processed one step " - "in advance. The number of steps in the previous round was " - "lastRunChannelId {} pythonBatchId {} hybridBatchId {}", - lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]); + "HybridMgmt is checking the step, and the parameters have been processed one step " + "in advance. 
The number of steps in the previous round was " + "lastRunChannelId {} pythonBatchId {} hybridBatchId {}", + lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]); return 1; } else { // 在通道切换时,hybrid处理的数据还没有赶上python侧,此时需要等待hybrid处理完成 LOG_DEBUG(HYBRID_BLOCKING + "ERROR FLAG lastRunChannelId {} hybridBatchId {}", - lastRunChannelId, hybridBatchId[lastRunChannelId]); + lastRunChannelId, hybridBatchId[lastRunChannelId]); return -1; } } @@ -224,16 +239,37 @@ void HybridMgmtBlock::SetRankInfo(RankInfo ri) this->stepsInterval[TRAIN_CHANNEL_ID] = ri.ctrlSteps[TRAIN_CHANNEL_ID]; this->stepsInterval[EVAL_CHANNEL_ID] = ri.ctrlSteps[EVAL_CHANNEL_ID]; this->saveInterval = ri.ctrlSteps[SAVE_STEP_INDEX]; + this->maxTrainStep = ri.ctrlSteps[MAX_TRAIN_STEP_INDEX]; this->rankInfo = ri; -}; +} void HybridMgmtBlock::SetStepInterval(int trainStep, int evalStep) { this->stepsInterval[0] = trainStep; this->stepsInterval[1] = evalStep; -}; +} HybridMgmtBlock::~HybridMgmtBlock() { Destroy(); } + +void HybridMgmtBlock::Wake(int channelId) +{ + isBlock[channelId] = false; +} + +bool HybridMgmtBlock::IsNeedWaitSave() +{ + if (saveInterval != 0 && saveInterval != -1 && + hybridBatchId[TRAIN_CHANNEL_ID] % saveInterval == 0 + && !finishSave) { + return true; + } + return false; +} + +void HybridMgmtBlock::FinishSave() +{ + finishSave = true; +} \ No newline at end of file diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.h b/src/core/hybrid_mgmt/hybrid_mgmt_block.h index 00cdc73ec5769154b1904a95260e3a6c5de7bdb8..f3ee6e8fe57eac8df4fdee5bf97f6fc87abcbc8c 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt_block.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.h @@ -26,11 +26,13 @@ See the License for the specific language governing permissions and namespace MxRec { const std::string HYBRID_BLOCKING = "[HYBRID_BLOCKING] "; const int SAVE_STEP_INDEX = 2; + const int MAX_TRAIN_STEP_INDEX = 3; const std::chrono::milliseconds SLEEP_MS = 20ms; class HybridMgmtBlock { public: HybridMgmtBlock() = default; + // 上一次运行的通道ID int lastRunChannelId = -1; // hybrid将要处理的batch id @@ -39,6 +41,13 @@ namespace MxRec { int pythonBatchId[2] = {0, 0}; // readEmbed算子侧将要处理的batch id int readEmbedBatchId[2] = {0, 0}; + // eval通道处理过的batch计数,不区分通道、图,不会重置;用于判断h2d swap是否需要eos + int evalBatchIdTotal = 0; + int maxTrainStep = 0; + int stepsInterval[2] = {0, 0}; // 通道i运行多少步后切换为通道j + + // hybrid已完成H2D的step;不区分通道、图,不会重置; + map h2dNextBatchId; int loop[2] = {1, 1}; @@ -76,21 +85,26 @@ namespace MxRec { void Destroy(); + void Wake(int channelId); + + bool IsNeedWaitSave(); + + void FinishSave(); + private: - // 通道i运行多少步后切换为通道j - int stepsInterval[2] = {0, 0}; // 控制通道阻塞的变量 bool isBlock[2] = {true, true}; // 控制训练了多少步进行保存的步数 int saveInterval = 0; RankInfo rankInfo; + bool finishSave = true; }; class HybridMgmtBlockingException : public std::exception { public: explicit HybridMgmtBlockingException(const string scene) { - HybridMgmtBlock *hybridMgmtBlock = Singleton::GetInstance(); + HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); int channelId = hybridMgmtBlock->lastRunChannelId; int preprocessBatchNumber = hybridMgmtBlock->hybridBatchId[channelId]; int currentBatchNumber = hybridMgmtBlock->pythonBatchId[channelId]; diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp index 1ea0084f93fac55a7226f435336fba9932937fe5..addc464764fc244164e99a8c46506d120f663da1 100644 --- 
a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp @@ -20,11 +20,10 @@ See the License for the specific language governing permissions and using namespace MxRec; RandomNormalInitializer::RandomNormalInitializer(int start, int len, NormalInitializerInfo& initInfo) - : start(start), len(len), mean(initInfo.mean), stddev(initInfo.stddev), seed(initInfo.seed) + : start(start), len(len), mean(initInfo.mean), stddev(initInfo.stddev), seed(initInfo.seed), + initParam(initInfo.initK), generator(std::default_random_engine(seed)), + distribution(std::normal_distribution(mean, stddev)) { - initParam = initInfo.initK; - generator = std::default_random_engine(seed); - distribution = std::normal_distribution(mean, stddev); } void RandomNormalInitializer::GenerateData(float* const emb, const int embSize) diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.h b/src/core/initializer/random_normal_initializer/random_normal_initializer.h index 9d5f9942fb03addf961acffd9dc30ed8ae5b44ff..e342f75f60d5e0e67785843bf5fadae2555047b1 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.h +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.h @@ -37,6 +37,7 @@ namespace MxRec { float mean; float stddev; int seed; + float initParam; std::default_random_engine generator; std::normal_distribution distribution; diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp index d50a7a97eb909f924bc2cd9216f32d7b92f99168..e011cfc723bbdd28296539e35126ce3d63a3c08b 100644 --- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp +++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp @@ -20,7 +20,8 @@ See the License for the specific language governing permissions and using namespace MxRec; TruncatedNormalInitializer::TruncatedNormalInitializer(int start, int len, NormalInitializerInfo& initInfo) - : start(start), len(len), seed(initInfo.seed) + : start(start), len(len), seed(initInfo.seed), generator(std::default_random_engine(initInfo.seed)), + distribution(std::normal_distribution(initInfo.mean, initInfo.stddev)) { initParam = initInfo.initK; // 校验stddev mean值范围 @@ -43,7 +44,6 @@ TruncatedNormalInitializer::TruncatedNormalInitializer(int start, int len, Norma stddev = initInfo.stddev; } - generator = std::default_random_engine(seed); distribution = std::normal_distribution(mean, stddev); minBound = initParam * (mean - static_cast(boundNum) * stddev); maxBound = initParam * (mean + static_cast(boundNum) * stddev); diff --git a/src/core/key_process/feature_admit_and_evict.cpp b/src/core/key_process/feature_admit_and_evict.cpp index fe7295b2b81c678c603eb587f1186adbea62834f..0305665a7887c043a2fa52955e7b41d69bac9a43 100644 --- a/src/core/key_process/feature_admit_and_evict.cpp +++ b/src/core/key_process/feature_admit_and_evict.cpp @@ -144,7 +144,7 @@ FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(const int channel, con } // 特征淘汰接口 -void FeatureAdmitAndEvict::FeatureEvict(map>& evictKeyMap) +void FeatureAdmitAndEvict::FeatureEvict(map>& evictKeyMap) { std::vector tableNames = GetAllNeedEvictTableNames(); if (tableNames.empty()) { @@ -163,7 +163,7 @@ void FeatureAdmitAndEvict::FeatureEvict(map> } } -void 
FeatureAdmitAndEvict::FeatureEvictHelper(const std::string& embName, std::vector& evictKey) +void FeatureAdmitAndEvict::FeatureEvictHelper(const std::string& embName, std::vector& evictKey) { // 从 m_historyRecords 中淘汰删除 time_t currTime = m_recordsData.timestamps[embName]; diff --git a/src/core/key_process/feature_admit_and_evict.h b/src/core/key_process/feature_admit_and_evict.h index 0b31b080c5dfc82d6c543136d22ba6c8f8dfe507..6c82c84620d91889ab3f317ff69104e9ef8b399e 100644 --- a/src/core/key_process/feature_admit_and_evict.h +++ b/src/core/key_process/feature_admit_and_evict.h @@ -25,7 +25,6 @@ See the License for the specific language governing permissions and #include #include #include "absl/container/flat_hash_map.h" -#include "host_emb/host_emb.h" #include "utils/common.h" #include "utils/safe_queue.h" #include "utils/singleton.h" @@ -69,7 +68,7 @@ namespace MxRec { KeysT& splitKey, std::vector& keyCount); // 特征淘汰接口 - void FeatureEvict(map>& evictKeyMap); + void FeatureEvict(map>& evictKeyMap); void ExecuteFeatureAdmit( const string& tableName, int channel, KeysT& splitKey, absl::flat_hash_map& mergeKeys); @@ -105,7 +104,7 @@ namespace MxRec { std::vector GetAllNeedEvictTableNames(); FeatureAdmitType FeatureAdmitHelper(const int channel, const std::string& tableNameOrigin, const int64_t featureId, const uint32_t featureCnt); - void FeatureEvictHelper(const std::string& embName, std::vector& evictKey); + void FeatureEvictHelper(const std::string& embName, std::vector& evictKey); void ResetAllRecords(); bool m_isEnableFunction { true }; // “特征淘汰”的使能开关 diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index f76f69078140268e7f9b37318c1b01f05840b167..1cb9f992533e898484e0a08f64819e0bd1b70416 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -24,7 +24,6 @@ See the License for the specific language governing permissions and #include "utils/singleton.h" #include "utils/time_cost.h" #include "utils/config.h" -#include "host_emb/host_emb.h" #include "emb_table/embedding_mgmt.h" #include "hd_transfer/hd_transfer.h" #include "ock_ctr_common/include/error_code.h" @@ -44,23 +43,20 @@ bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos const vector& thresholdValues, int seed) { + readySendEosCnt[TRAIN_CHANNEL_ID].store(0); + readySendEosCnt[EVAL_CHANNEL_ID].store(0); + finishSendEosCnt[TRAIN_CHANNEL_ID].store(0); + finishSendEosCnt[EVAL_CHANNEL_ID].store(0); + this->rankInfo = rInfo; - if (rankInfo.useHot) { - SetupHotEmbUpdateStep(); - } + + SetupHotEmbUpdateStep(); map scInfo; for (const auto& info: eInfos) { embInfos[info.name] = info; scInfo[info.name] = info.sendCount; - if (rankInfo.useHot) { - InitHotEmbTotCount(info, rInfo); - } - if (rankInfo.useDynamicExpansion) { - // 动态扩容 - embeddingTableMap[info.name].Init(info, rInfo, seed); - LOG_INFO(KEY_PROCESS "EmbeddingTableMap:{} init success", info.name); - } + InitHotEmbTotCount(info, rInfo); } LOG_INFO(KEY_PROCESS "hot emb count info:{}", MapToString(hotEmbTotCount)); @@ -82,15 +78,8 @@ bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos LOG_WARN(KEY_PROCESS "Feature admit-and-evict function is unavailable ..."); } - if (GlobalEnv::fastUnique) { - int result = ock::ctr::Factory::Create(factory); - if (result != 0) { - throw runtime_error(Logger::Format("create fast factory failed, error code:{}", result)); - } - } - - LOG_INFO(KEY_PROCESS "scInfo:{}, localRankSize:{}, rankSize:{}, useStatic:{}, useHot:{}", - 
MapToString(scInfo), rInfo.localRankSize, rInfo.rankSize, rInfo.useStatic, rInfo.useHot); + LOG_INFO(KEY_PROCESS "scInfo:{}, localRankSize:{}, rankSize:{}, useStatic:{}", + MapToString(scInfo), rInfo.localRankSize, rInfo.rankSize, rInfo.useStatic); #ifndef GTEST Start(); #endif @@ -132,12 +121,8 @@ int KeyProcess::Start() void KeyProcess::InitHotEmbTotCount(const EmbInfo& info, const RankInfo& rInfo) { - int embeddingSize = info.extEmbeddingSize; - if (rankInfo.useDynamicExpansion) { - embeddingSize = info.embeddingSize; - } hotEmbTotCount[info.name] = static_cast(static_cast(GetUBSize(rInfo.deviceId) / sizeof(float)) * - HOT_EMB_CACHE_PCT / static_cast(embeddingSize)); + HOT_EMB_CACHE_PCT / static_cast(info.embeddingSize)); } OffsetMemT KeyProcess::GetMaxOffset() @@ -342,11 +327,7 @@ void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector name] != SingleEmbTableStatus::SETS_NONE) { tie(splitKeys, restore, keyCount) = HashSplitWithFAAE(batch); // 按存储dev id切分并去重 } else { - if (rankInfo.useHot) { - tie(splitKeys, restore, hotPos) = HotHashSplit(batch); // 按存储dev id切分并去重 - } else { - tie(splitKeys, restore) = HashSplit(batch); // 按存储dev id切分并去重 - } + tie(splitKeys, restore, hotPos) = HotHashSplit(batch); // 按存储dev id切分并去重 } LOG_DEBUG("uniqueTc(ms):{}", uniqueTc.ElapsedMS()); } @@ -385,26 +366,32 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& batch // Static all2all,need send count if (!rankInfo.useStatic) { SendA2A(uniqueInfo.all2AllInfo.scAll, batch->name, batch->channel, batch->batchId); } + TimeCost pushResultTC; auto tensors = make_unique>(); tensors->push_back(Vec2TensorI32(uniqueInfo.restore)); - if (rankInfo.useHot) { - uniqueInfo.hotPos.resize(hotEmbTotCount[batch->name], -1); - tensors->push_back(Vec2TensorI32(uniqueInfo.hotPos)); - } + + uniqueInfo.hotPos.resize(hotEmbTotCount[batch->name], -1); + tensors->push_back(Vec2TensorI32(uniqueInfo.hotPos)); if (!rankInfo.isDDR) { PushGlobalUniqueTensors(move(tensors), uniqueInfo.all2AllInfo.keyRecv, channel); tensors->push_back(rankInfo.useDynamicExpansion ? Vec2TensorI64(uniqueInfo.all2AllInfo.keyRecv) : - Vec2TensorI32(uniqueInfo.all2AllInfo.keyRecv)); + Vec2TensorI32(uniqueInfo.all2AllInfo.keyRecv)); + PushResultHBM(batch, move(tensors)); + } else { + std::vector lookupKeysUint(uniqueInfo.all2AllInfo.keyRecv.begin(), + uniqueInfo.all2AllInfo.keyRecv.end()); + vector uniqueKeys; + vector restoreVecSec; + GlobalUnique(lookupKeysUint, uniqueKeys, restoreVecSec); + PushResultDDR(batch, move(tensors), uniqueKeys, restoreVecSec); } - TimeCost pushResultTC; - PushResult(batch, move(tensors), uniqueInfo.all2AllInfo.keyRecv); + LOG_DEBUG("pushResultTC(ms):{}", pushResultTC.ElapsedMS()); if (GlogConfig::gStatOn) { LOG_INFO(STAT_INFO "channel_id {} batch_id {} rank_id {} key_process_time_cost_with_fast_unique {}", channel, batch->batchId, rankInfo.rankId, totalTimeCost.ElapsedMS()); } - LOG_DEBUG("pushResultTC(ms):{}", pushResultTC.ElapsedMS()); return true; } @@ -449,17 +436,22 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, TimeCost pushResultTC; auto tensors = make_unique>(); tensors->push_back(Vec2TensorI32(restore)); - if (rankInfo.useHot) { - hotPos.resize(hotEmbTotCount[batch->name], 0); - tensors->push_back(Vec2TensorI32(hotPos)); - } + + hotPos.resize(hotEmbTotCount[batch->name], 0); + tensors->push_back(Vec2TensorI32(hotPos)); if (!rankInfo.isDDR) { PushGlobalUniqueTensors(tensors, lookupKeys, channel); tensors->push_back(rankInfo.useDynamicExpansion ? 
Vec2TensorI64(lookupKeys) : Vec2TensorI32(lookupKeys)); + PushResultHBM(batch, move(tensors)); + } else { + std::vector lookupKeysUint(lookupKeys.begin(), lookupKeys.end()); + vector uniqueKeys; + vector restoreVecSec; + GlobalUnique(lookupKeysUint, uniqueKeys, restoreVecSec); + PushResultDDR(batch, move(tensors), uniqueKeys, restoreVecSec); } - PushResult(batch, move(tensors), lookupKeys); LOG_DEBUG("pushResultTC(ms):{}", pushResultTC.ElapsedMS()); if (GlogConfig::gStatOn) { LOG_INFO(STAT_INFO "channel_id {} batch_id {} rank_id {} key_process_time_cost {}", @@ -470,8 +462,9 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, void KeyProcess::PushGlobalUniqueTensors(const unique_ptr>& tensors, KeysT& lookupKeys, int channel) { - if (GlobalEnv::applyGradientsStrategy == ApplyGradientsStrategyOptions::SUM_SAME_ID_GRADIENTS_AND_APPLY && - channel == TRAIN_CHANNEL_ID) { + LOG_INFO(KEY_PROCESS "rank:{}, channel:{}, useSumSameIdGradients:{} ...", + rankInfo.rankId, channel, rankInfo.useSumSameIdGradients); + if (rankInfo.useSumSameIdGradients && channel == TRAIN_CHANNEL_ID) { KeysT uniqueKeys; vector restoreVecSec; @@ -516,15 +509,22 @@ vector KeyProcess::GetCountRecv(const unique_ptr& batch, in return countRecv; } -void KeyProcess::PushResult(unique_ptr& batch, unique_ptr> tensors, - KeysT& lookupKeys) +void KeyProcess::PushResultHBM(unique_ptr& batch, unique_ptr> tensors) { std::unique_lock lockGuard(mut); storage.push_front(move(tensors)); infoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, storage.begin())); - if (rankInfo.isDDR) { - lookupKeysList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, move(lookupKeys))); - } + lockGuard.unlock(); +} + +void KeyProcess::PushResultDDR(unique_ptr& batch, unique_ptr> tensors, + std::vector& uniqueKeys, std::vector& restoreVecSec) +{ + std::unique_lock lockGuard(mut); + storage.push_front(move(tensors)); + infoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, storage.begin())); + uniqueKeysList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, move(uniqueKeys))); + restoreVecSecList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, move(restoreVecSec))); lockGuard.unlock(); } @@ -651,17 +651,15 @@ void KeyProcess::HandleHotAndSendCount(const unique_ptr &batch, Uniqu absl::flat_hash_map hotMap = hotKey[batch->name]; lock.unlock(); - if (rankInfo.useHot) { - int hotOffset = 0; - uniqueInfoOut.hotPos.resize(hotEmbTotCount[batch->name]); - hotOffset = hotEmbTotCount[batch->name]; + int hotOffset = 0; + uniqueInfoOut.hotPos.resize(hotEmbTotCount[batch->name]); + hotOffset = hotEmbTotCount[batch->name]; - TimeCost computeHotTc; - ComputeHotPos(batch, hotMap, uniqueInfoOut.hotPos, uniqueInfoOut.restore, hotOffset); - LOG_DEBUG("ComputeHot TimeCost(ms):{}", computeHotTc.ElapsedMS()); - UpdateHotMapForUnique(keySendInfo.keySend, keySendInfo.keyCount, - hotOffset, batch->batchId % hotEmbUpdateStep == 0, batch->name); - } + TimeCost computeHotTc; + ComputeHotPos(batch, hotMap, uniqueInfoOut.hotPos, uniqueInfoOut.restore, hotOffset); + LOG_DEBUG("ComputeHot TimeCost(ms):{}", computeHotTc.ElapsedMS()); + UpdateHotMapForUnique(keySendInfo.keySend, keySendInfo.keyCount, + hotOffset, batch->batchId % hotEmbUpdateStep == 0, batch->name); if (rankInfo.useStatic) { sc.resize(rankInfo.rankSize, embInfos[batch->name].sendCount); @@ -1111,40 +1109,6 @@ void KeyProcess::Key2Offset(const EmbNameT& embName, KeysT& splitKey, 
int channe embName, maxOffsetTmp, embInfos[embName].devVocabSize, key2OffsetTC.ElapsedMS()); } -void KeyProcess::Key2OffsetDynamicExpansion(const EmbNameT& embName, KeysT& splitKey, int channel) -{ - TimeCost key2OffsetTC; - EASY_FUNCTION(profiler::colors::Blue600) - std::lock_guard lk(mut); // lock for PROCESS_THREAD - auto& key2Offset = keyOffsetMap[embName]; - auto& maxOffsetTmp = maxOffset[embName]; - auto& curEmbTable = embeddingTableMap[embName]; // empty when not use dynamic expansion - for (long& key : splitKey) { - if (key == -1) { - key = 0; - continue; - } - const auto& iter = key2Offset.find(key); - if (iter != key2Offset.end()) { - key = iter->second; - } else { - // 新值 - if (channel == TRAIN_CHANNEL_ID) { -#ifndef GTEST - int64_t addr = curEmbTable.GetEmbAddress(); - key2Offset[key] = addr; - key = addr; -#endif - maxOffsetTmp++; - continue; - } - key = 0; - } - } - LOG_DEBUG("current expansion emb:{}, usage:{}/{}, key2OffsetTC({} ms)", - embName, maxOffsetTmp, embInfos[embName].devVocabSize, key2OffsetTC.ElapsedMS()); -} - /* * 构建恢复向量,以便从去重后的emb向量/key恢复回batch对应的emb向量 * 输入接收到emb块的偏移blockOffset,batch内每个key在块内的偏移restoreVec @@ -1172,33 +1136,124 @@ void KeyProcess::BuildRestoreVec(const unique_ptr& batch, const vecto } template -T KeyProcess::GetInfo(info_list_t& list, int batch, const string& embName, int channel) +T KeyProcess::GetInfo(info_list_t& list, const EmbBaseInfo &info) { std::lock_guard lockGuard(mut); - if (list[embName][channel].empty()) { + if (list[info.name][info.channelId].empty()) { LOG_TRACE("get info list is empty."); throw EmptyList(); } - auto topBatch = get(list[embName][channel].top()); - if (topBatch < batch) { - LOG_ERROR("wrong batch id, top:{} getting:{}, channel:{}, may not clear channel", topBatch, batch, channel); + auto topBatch = get(list[info.name][info.channelId].top()); + if (topBatch < info.batchId) { + LOG_ERROR("wrong batch id, top:{} getting:{}, channel:{}, may not clear channel", + topBatch, info.batchId, info.channelId); this_thread::sleep_for(1s); } - if (topBatch != batch) { - LOG_TRACE("topBatch({}) is not equal batch({}).", topBatch, batch); + if (topBatch != info.batchId) { + LOG_TRACE("topBatch({}) is not equal batch({}).", topBatch, info.batchId); throw WrongListTop(); } - auto t = list[embName][channel].top(); - list[embName][channel].pop(); + auto t = list[info.name][info.channelId].top(); + list[info.name][info.channelId].pop(); return move(t); } -/// DDR模式下,从list中获取查询tensor向量 -/// \param batch 已处理的batch数 -/// \param embName 表名 -/// \param channel 通道索引(训练/推理) -/// \return -KeysT KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) +vector KeyProcess::GetUniqueKeys(const EmbBaseInfo& info, bool& isEos, + map &lookUpSwapInAddrsPushId) +{ + TimeCost tc = TimeCost(); + + HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); + bool cancelMonitor = false; + thread timeoutMonitor; + if (info.batchId != 0) { + timeoutMonitor = StartEosMonitorThread(info, cancelMonitor); + } + + // 循环尝试获取list中的数据;如果key process线程退出或者处理数据超时,返回空vector + + vector ret; + auto startTime = std::chrono::system_clock::now(); + while (true) { + if (!isRunning) { + break; + } + auto endTime = std::chrono::system_clock::now(); + // 判断此时的info.batchId id是否已经过期,即通道已经刷新 + if (info.batchId != hybridMgmtBlock->hybridBatchId[info.channelId]) { + LOG_DEBUG(KEY_PROCESS "Detected that the batch has expired at this time, exiting the loop! 
{}[{}]:{}", + info.name, info.channelId, info.batchId); + break; + } + if (info.batchId != 0 && info.channelId != 0 && tc.ElapsedSec() > KEY_PROCESS_TIMEOUT) { + LOG_WARN(KEY_PROCESS "getting lookup keys timeout! {}[{}]:{}", + info.name, info.channelId, info.batchId); + break; + } + try { + auto infoVec = GetInfo(uniqueKeysList, info); + ret = get>(infoVec); + break; + } catch (EmptyList&) { + unique_lock lockEosGuard(eosMutex); + isEos = IsGetUniqueKeysEos(info, startTime, lookUpSwapInAddrsPushId); + if (isEos) { + break; + } + this_thread::sleep_for(1ms); + } catch (WrongListTop&) { + LOG_TRACE("getting info failed table:{}, channel:{}, mgmt batchId:{}, wrong top", + info.name, info.channelId, info.batchId); + this_thread::sleep_for(1ms); + } + } + cancelMonitor = true; + if (timeoutMonitor.joinable()) { + timeoutMonitor.join(); + } + return ret; +} + +bool KeyProcess::IsGetUniqueKeysEos(const EmbBaseInfo& info, std::chrono::system_clock::time_point& startTime, + map& lookUpSwapInAddrsPushId) +{ + HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); + auto endTime = std::chrono::system_clock::now(); + + // readEmbKey starts with 0 + int readEmbKeyBatchId = hybridMgmtBlock->readEmbedBatchId[info.channelId] - 1; + // 避免eos在keyProcess还未处理完数据时插队到通道前面 + std::chrono::duration elapsedTime = endTime - startTime; + // train and eval batch total num + int allChannelBatchId = 0; + if (info.channelId == EVAL_CHANNEL_ID) { + allChannelBatchId = hybridMgmtBlock->evalBatchIdTotal + hybridMgmtBlock->hybridBatchId[TRAIN_CHANNEL_ID] + + hybridMgmtBlock->readEmbedBatchId[info.channelId]; + } else { + allChannelBatchId = hybridMgmtBlock->evalBatchIdTotal + hybridMgmtBlock->readEmbedBatchId[info.channelId]; + } + if (info.batchId != 0 && elapsedTime.count() >= timeoutGetUniqueKeysEmpty) { + LOG_DEBUG("table:{}, channelId:{}, isNeedSendEos:{}, readEmbKeyBatchId:{}, batch:{}, h2dNextBatchId:{}," + " lookUpSwapInAddrsPushId:{}, allChannelBatchId:{}", info.name, info.channelId, + isNeedSendEos[info.channelId], readEmbKeyBatchId, info.batchId, + hybridMgmtBlock->h2dNextBatchId[info.name], lookUpSwapInAddrsPushId[info.name], allChannelBatchId); + startTime = std::chrono::system_clock::now(); + } + // Check the '>= allChannelBatchId' condition to avoid sending eos before all batch data from the readEmbKey op has been handled. 
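+ // Put differently, the check below requires three conditions to hold at the same time:
+ // (1) upstream has raised the eos flag for this channel (isNeedSendEos),
+ // (2) the readEmbKey op has produced no batch at or beyond the one being waited on
+ //     (readEmbKeyBatchId < info.batchId), and
+ // (3) the H2D swap stage has drained everything it was handed: h2dNextBatchId matches the
+ //     swap-in push id for this table and has reached the combined train/eval count (allChannelBatchId).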
+ if (isNeedSendEos[info.channelId] && readEmbKeyBatchId < info.batchId && + hybridMgmtBlock->h2dNextBatchId[info.name] == lookUpSwapInAddrsPushId[info.name] && + hybridMgmtBlock->h2dNextBatchId[info.name] >= allChannelBatchId) { + LOG_INFO("table:{}, channelId:{} batchId:{}, GetUniqueKeys eos, h2dNextBatchId:{}, allChannelBatchId:{}", + info.name, info.channelId, info.batchId, hybridMgmtBlock->h2dNextBatchId[info.name], + allChannelBatchId); + return true; + } + LOG_TRACE("getting uniqueKeys failed, table:{}, channel:{}, mgmt batchId:{}, readEmbKey batchId:{}, list is empty", + info.name, info.channelId, info.batchId, readEmbKeyBatchId); + return false; +} + +std::vector KeyProcess::GetRestoreVecSec(const EmbBaseInfo& info) { TimeCost tc = TimeCost(); // 循环尝试获取list中的数据;如果key process线程退出或者处理数据超时,返回空vector @@ -1208,74 +1263,80 @@ KeysT KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) } // 判断此时的batch id是否已经过期,即通道已经刷新 HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); - if (batch != hybridMgmtBlock->hybridBatchId[channel]) { + if (info.batchId != hybridMgmtBlock->hybridBatchId[info.channelId]) { LOG_DEBUG(KEY_PROCESS "Detected that the batch has expired at this time, exiting the loop! {}[{}]:{}", - embName, channel, batch); + info.name, info.channelId, info.batchId); return {}; } - if (batch != 0 && channel != 0 && tc.ElapsedSec() > KEY_PROCESS_TIMEOUT) { - LOG_WARN(KEY_PROCESS "getting lookup keys timeout! {}[{}]:{}", embName, channel, batch); + if (info.batchId != 0 && info.channelId != 0 && tc.ElapsedSec() > KEY_PROCESS_TIMEOUT) { + LOG_WARN(KEY_PROCESS "getting lookup keys timeout! {}[{}]:{}", info.name, info.channelId, info.batchId); return {}; } try { - auto ret = GetInfo(lookupKeysList, batch, embName, channel); - return get(ret); + auto ret = GetInfo(restoreVecSecList, info); + return get>(ret); } catch (EmptyList&) { unique_lock lockEosGuard(eosMutex); // readEmbKey真实的次数是readEmbedBatchId减1 - int readEmbKeyBatchId = hybridMgmtBlock->readEmbedBatchId[channel] - 1; + int readEmbKeyBatchId = hybridMgmtBlock->readEmbedBatchId[info.channelId] - 1; // 避免eos在keyProcess还未处理完数据时插队到通道前面 - if (isNeedSendEos[channel] && readEmbKeyBatchId < batch) { - LOG_INFO("channelId:{} batchId:{}, GetLookupKeys eos.", channel, batch); - unique_lock lockDestroyGuard(destroyMutex); - SendEos(batch, channel); - return {}; + if (isNeedSendEos[info.channelId] && readEmbKeyBatchId < info.batchId && + hybridMgmtBlock->h2dNextBatchId[info.name] == info.batchId) { + LOG_ERROR("channelId:{} batchId:{}, GetRestoreVecSec eos, code should not reach here", + info.channelId, info.batchId); + throw runtime_error("GetRestoreVecSec eos, code should not reach here"); } LOG_TRACE("getting info failed {}[{}], list is empty, and mgmt batchId: {}, readEmbKey batchId: {}.", - embName, channel, batch, readEmbKeyBatchId); + info.name, info.channelId, info.batchId, readEmbKeyBatchId); this_thread::sleep_for(1ms); } catch (WrongListTop&) { - LOG_TRACE("getting info failed {}[{}]:{} wrong top", embName, channel, batch); + LOG_TRACE("getting info failed {}[{}]:{} wrong top", info.name, info.channelId, info.batchId); this_thread::sleep_for(1ms); } } } /// 当数据列表为空,且eos标志位为true时,主动发送eos +/// \param embName 表名 /// \param batchId 已处理的batch数 /// \param channel 通道索引(训练/推理) -void KeyProcess::SendEos(int batchId, int channel) +/// \param sendAllChannel 是否强制发送所有channel +void KeyProcess::SendEos(const std::string& embName, int batchId, int channel, bool sendAllChannel) { #ifndef GTEST - LOG_INFO("channelId:{} 
batchId:{}, SendEos start.", channel, batchId); - - auto trans = Singleton::GetInstance(); - unordered_map transChannels = trans->GetTransChannel(); - std::set usedChannelNames = trans->GetUsedTransChannel()[channel]; - - vector tensors; - bool isNeedResend = true; - - for (const auto& emb: as_const(embInfos)) { // 一个表触发以后,其余表都发送eos,最后外层接收null退出此次循环 - LOG_INFO("channelId:{} batchId:{}, the embName:{} related channel SendEos start.", channel, batchId, emb.first); - if (!isRunning) { - throw EndRunExit("SendEos end run, isRunning is false after lock destroyMutex."); - } - for (const string& transName : usedChannelNames) { - string sendName = StringFormat("%s_%s_%d", emb.first.c_str(), transName.c_str(), channel); - size_t channelSize = 0; - - acltdtQueryChannelSize(transChannels[sendName], &channelSize); - LOG_INFO("[EOS] Before send eos, {} contains {}.", sendName, channelSize); - SendTensorsByAcl(transChannels[sendName], ACL_TENSOR_DATA_END_OF_SEQUENCE, tensors, isNeedResend); - acltdtQueryChannelSize(transChannels[sendName], &channelSize); - LOG_INFO("[EOS] After send eos, {} contains {}.", sendName, channelSize); - } - LOG_INFO("channelId:{} batchId:{}, the embName:{} related channel SendEos end.", channel, batchId, emb.first); + finishSendEosCnt[channel].store(0); + ++readySendEosCnt[channel]; + LOG_INFO("table:{}, channelId:{} batchId:{}, readySendEosCnt:{}, ready to SendEos", + embName, channel, batchId, readySendEosCnt[channel]); + while (readySendEosCnt[channel] != static_cast(embInfos.size())) { + LOG_DEBUG("table:{}, readySendEosCnt:{}, waiting for other tables to enter SendEos", embName, readySendEosCnt[channel]); + this_thread::sleep_for(1000ms); + } + LOG_INFO("table:{}, channelId:{} batchId:{}, SendEos start, acquiring destroyMutex", embName, channel, batchId); + destroyMutex.lock(); + + LOG_INFO("table:{}, channelId:{} batchId:{}, SendEos start", embName, channel, batchId); + if (!isRunning) { + LOG_INFO("another table triggered eos ahead, keyProcess already destroyed. 
skip sending eos for table:{}", embName); + ++finishSendEosCnt[channel]; + destroyMutex.unlock(); + return; } + SendEosTensor(embName, channel, sendAllChannel); + destroyMutex.unlock(); + LOG_INFO("channelId:{} batchId:{}, the embName:{} SendEos end, release destroyMutex", channel, batchId, embName); - LOG_INFO("channelId:{} batchId:{}, SendEos end.", channel, batchId); + ++finishSendEosCnt[channel]; + LOG_INFO("table:{}, channelId:{} batchId:{}, finishSendEosCnt:{}, finish SendEos", + embName, channel, batchId, finishSendEosCnt[channel]); + while (finishSendEosCnt[channel] != static_cast(embInfos.size())) { + LOG_DEBUG("table:{}, channelId:{} batchId:{}, finishSendEosCnt:{}, waiting for other tables to finish SendEos", + embName, channel, batchId, finishSendEosCnt[channel]); + this_thread::sleep_for(1000ms); + } + readySendEosCnt[channel].store(0); isNeedSendEos[channel] = false; + LOG_DEBUG("isNeedSendEos set to false, table:{}, channelId:{} batchId:{}", embName, channel, batchId); #endif } @@ -1285,7 +1346,7 @@ void KeyProcess::SendEos(int batchId, int channel) /// \param channel 通道索引(训练/推理) /// \param type 数据类型 /// \return -unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embName, int channel, ProcessedInfo type) +unique_ptr> KeyProcess::GetInfoVec(const EmbBaseInfo &info, ProcessedInfo type, bool &isEos) { TimeCost tc = TimeCost(); info_list_t* list; @@ -1302,47 +1363,46 @@ unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embNa throw std::invalid_argument("Invalid ProcessedInfo Type."); } + unique_ptr> ret = nullptr; // 循环尝试获取list中的数据;如果key process线程退出或者处理数据超时,返回空指针 while (true) { if (!isRunning) { - return nullptr; + break; } // 判断此时的batch id是否已经过期,即通道已经刷新 HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); - if (batch != hybridMgmtBlock->hybridBatchId[channel]) { + if (info.batchId != hybridMgmtBlock->hybridBatchId[info.channelId]) { LOG_DEBUG(KEY_PROCESS "Detected that the batch has expired at this time, exiting the loop! {}[{}]:{}", - embName, channel, batch); - return nullptr; + info.name, info.channelId, info.batchId); + break; } - if (batch != 0 && channel != 0 && tc.ElapsedSec() > KEY_PROCESS_TIMEOUT) { - LOG_WARN(KEY_PROCESS "getting lookup keys timeout! {}[{}]:{}", embName, channel, batch); - return nullptr; + if (info.batchId != 0 && info.channelId != 0 && tc.ElapsedSec() > KEY_PROCESS_TIMEOUT) { + LOG_WARN(KEY_PROCESS "getting lookup keys timeout! 
{}[{}]:{}", info.name, info.channelId, info.batchId); + break; } try { - auto ret = GetInfo(*list, batch, embName, channel); - auto it = get>>::iterator>(ret); - auto uTensor = move(*it); + auto infoVec = GetInfo(*list, info); + auto it = get>>::iterator>(infoVec); + ret = std::move(*it); std::unique_lock lockGuard(mut); storage.erase(it); - return uTensor; + break; } catch (EmptyList&) { unique_lock lockEosGuard(eosMutex); - // 避免eos在keyProcess还未处理完数据时插队到通道前面, readEmbKey真实的次数是readEmbedBatchId减1 - if (isNeedSendEos[channel] && (hybridMgmtBlock->readEmbedBatchId[channel] - 1) < batch) { - LOG_INFO("channelId:{} batchId:{}, GetInfoVec eos.", channel, batch); - unique_lock lockDestroyGuard(destroyMutex); - SendEos(batch, channel); - return nullptr; + isEos = IsGetInfoVecEos(info.batchId, info.name, info.channelId); + if (isEos) { + break; } LOG_TRACE("getting info failed {}[{}], list is empty, and mgmt batchId: {}, readEmbKey batchId: {}.", - embName, channel, batch, (hybridMgmtBlock->readEmbedBatchId[channel] - 1)); + info.name, info.channelId, info.batchId, (hybridMgmtBlock->readEmbedBatchId[info.channelId] - 1)); this_thread::sleep_for(1ms); } catch (WrongListTop&) { - LOG_TRACE("getting info failed {}[{}]:{} wrong top", embName, channel, batch); + LOG_TRACE("getting info failed {}[{}]:{} wrong top", info.name, info.channelId, info.batchId); this_thread::sleep_for(1ms); } } + return ret; } void KeyProcess::SendA2A(const vector& a2aInfo, const string& embName, int channel, int batch) @@ -1369,13 +1429,13 @@ int KeyProcess::GetMaxStep(int channelId) const return rankInfo.ctrlSteps.at(channelId); } -void KeyProcess::EvictKeys(const string& embName, const vector& keys) // hbm +void KeyProcess::EvictKeys(const string& embName, const vector& keys) // hbm { LOG_INFO(KEY_PROCESS "hbm funEvictCall: [{}]! 
keySize:{}", embName, keys.size()); EmbeddingMgmt::Instance()->EvictKeys(embName, keys); } -void KeyProcess::EvictKeysCombine(const vector& keys) // hbm +void KeyProcess::EvictKeysCombine(const vector& keys) // hbm { LOG_INFO(KEY_PROCESS "hbm combine funEvictCall, keySize:{}", keys.size()); EmbeddingMgmt::Instance()->EvictKeysCombine(keys); @@ -1480,7 +1540,94 @@ void KeyProcess::RecordKeyCountMap(const unique_ptr& batch) void KeyProcess::SetEos(int status, int channelId) { unique_lock lockGuard(eosMutex); - LOG_INFO("isNeedSendEos status is changed, before status:[{}], input status:{}, channel:[{}], ", - isNeedSendEos[channelId], status, channelId); + LOG_INFO("isNeedSendEos status is changed, channel:{}, before status:{}, input status:{}", + channelId, isNeedSendEos[channelId], status); isNeedSendEos[channelId] = (status == 1); } + +bool KeyProcess::IsGetInfoVecEos(int batch, const string& embName, int channel) +{ + HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); + + // 避免eos在keyProcess还未处理完数据时插队到通道前面, readEmbKey真实的次数是readEmbedBatchId减1 + int readEmbKeyBatchId = hybridMgmtBlock->readEmbedBatchId[channel] - 1; + if (rankInfo.isDDR) { + if (isNeedSendEos[channel] && readEmbKeyBatchId < batch && + hybridMgmtBlock->h2dNextBatchId[embName] == batch) { + LOG_ERROR("channelId:{} batchId:{}, GetInfoVec eos, code should not reach here", channel, batch); + throw runtime_error("GetInfoVec eos, code should not reach here"); + } + } else { + LOG_TRACE("table:{}, channelId:{}, readEmbKeyBatchId:{}, batchId:{}, isNeedSendEos:{}", + embName, channel, readEmbKeyBatchId, batch, isNeedSendEos[channel]); + if (isNeedSendEos[channel] && readEmbKeyBatchId < batch) { + LOG_INFO("table:{}, channelId:{} batchId:{}, GetInfoVec eos", embName, channel, batch); + return true; + } + } + return false; +} + +std::thread KeyProcess::StartEosMonitorThread(const EmbBaseInfo &info, bool &cancelMonitor) +{ + // 由于embCache延迟发送swapPos的特性,step n需要step n+1的数据来启动,当获取不到step n+1时,需要触发eos并补发step n需要的swapPos + LOG_DEBUG("table:{}, channel:{}, batchId:{}, start a monitor thread to check eos", + info.name, info.channelId, info.batchId); + return thread([&]() { + chrono::high_resolution_clock::time_point start = chrono::high_resolution_clock::now(); + chrono::high_resolution_clock::time_point end = chrono::high_resolution_clock::now(); + chrono::duration duration = chrono::duration_cast>(end - start); + while (!cancelMonitor && duration.count() < timeoutGetUniqueKeys) { + this_thread::sleep_for(1ms); + end = chrono::high_resolution_clock::now(); + duration = chrono::duration_cast>(end - start); + } + if (!cancelMonitor) { + this->SetEos(1, info.channelId); + LOG_INFO("table:{}, channel:{}, batchId:{}, timeout:{}(s) monitor empty data, set eos", + info.name, info.channelId, info.batchId, timeoutGetUniqueKeys); + } else { + LOG_DEBUG("table:{}, channel:{}, batchId:{}, timeout monitor canceled", + info.name, info.channelId, info.batchId); + } + }); +} + +void KeyProcess::SendEosTensor(const std::string& embName, int channel, bool sendAllChannel) +{ +#ifndef GTEST + auto trans = Singleton::GetInstance(); + unordered_map transChannels = trans->GetTransChannel(); + std::set usedChannelNames = trans->GetUsedTransChannel()[channel]; + + vector tensors; + bool isNeedResend = true; + string sendName; + for (const string& transName : usedChannelNames) { + if (transName == TransferChannel2Str(TransferChannel::SAVE_D2H) || + transName == TransferChannel2Str(TransferChannel::SAVE_H2D)) { + // do nothing on save channel, it's independent 
of train, eval and predict channels; + continue; + } + + if (transName == TransferChannel2Str(TransferChannel::SWAP) || + transName == TransferChannel2Str(TransferChannel::H2D)) { + sendName = StringFormat("%s_%s_all", embName.c_str(), transName.c_str()); + if (channel == EVAL_CHANNEL_ID && !sendAllChannel) { + LOG_INFO("skip sending eos for shared channel:{}, channel id:{}", sendName, channel); + LOG_INFO("check if train ProcessEmbInfo runs and let it decide whether to send eos"); + continue; + } + } else { + sendName = StringFormat("%s_%s_%d", embName.c_str(), transName.c_str(), channel); + } + + size_t channelSize = 0; + acltdtQueryChannelSize(transChannels[sendName], &channelSize); + LOG_INFO("[EOS] Before send eos, channel:{}, size:{}.", sendName, channelSize); + SendTensorsByAcl(transChannels[sendName], ACL_TENSOR_DATA_END_OF_SEQUENCE, tensors, isNeedResend); + acltdtQueryChannelSize(transChannels[sendName], &channelSize); + LOG_INFO("[EOS] After send eos, channel:{}, size:{}.", sendName, channelSize); + } +#endif +} diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 8bd7b8d0795162fd1778dcf00e41ddc793227020..d57130b7b286d884ef62a380079bddeb4f15c90b 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -28,7 +28,6 @@ See the License for the specific language governing permissions and #include #include #include "ock_ctr_common/include/factory.h" #include "utils/common.h" -#include "emb_table/emb_table.h" #include "feature_admit_and_evict.h" #include "hybrid_mgmt/hybrid_mgmt_block.h" #include "utils/singleton.h" @@ -83,9 +82,11 @@ namespace MxRec { bool Initialize(const RankInfo& rInfo, const vector& eInfos, const vector& thresholdValues = {}, int seed = 0); - unique_ptr> GetInfoVec(int batch, const string& embName, int channel, ProcessedInfo type); + unique_ptr> GetInfoVec(const EmbBaseInfo& info, ProcessedInfo type, bool &isEos); - KeysT GetLookupKeys(int batch, const string& embName, int channel); + vector GetUniqueKeys(const EmbBaseInfo &info, bool &isEos, map &lookUpSwapInAddrsPushId); + + vector GetRestoreVecSec(const EmbBaseInfo& info); int GetMaxStep(int channelId) const; @@ -109,9 +110,9 @@ namespace MxRec { void LoadSaveUnlock(); - void EvictKeys(const string& embName, const vector& keys); + void EvictKeys(const string& embName, const vector& keys); - void EvictKeysCombine(const vector& keys); + void EvictKeysCombine(const vector& keys); void SetupHotEmbUpdateStep(); @@ -157,7 +158,7 @@ namespace MxRec { void SetEos(int status, int channelId); - void SendEos(int batchId, int channel); + void SendEos(const string& embName, int batchId, int channel, bool sendAllChannel); bool isRunning { false }; @@ -167,12 +168,13 @@ namespace MxRec { { return embInfos.find(embName) != embInfos.end(); }; + GTEST_PRIVATE: int Start(); template - T GetInfo(info_list_t& list, int batch, const string& embName, int channel); + T GetInfo(info_list_t& list, const EmbBaseInfo &info); RankInfo rankInfo; map embInfos; @@ -181,6 +183,8 @@ namespace MxRec { vector> procThreads {}; std::mutex loadSaveMut[MAX_CHANNEL_NUM][MAX_KEY_PROCESS_THREAD] {}; info_list_t lookupKeysList; + info_list_t uniqueKeysList; + info_list_t restoreVecSecList; list>> storage; info_list_t infoList; info_list_t all2AllList; @@ -191,11 +195,16 @@ namespace MxRec { map> evictPosMap {}; map> hotKey {}; map hotEmbTotCount; - map embeddingTableMap {}; ock::ctr::FactoryPtr factory {}; int hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT; bool isWithFAAE; - bool isNeedSendEos[2] = { 0, 0 }; // 
分别代表通道0、1的eos状态 + + // for end-of-sequence case + bool isNeedSendEos[2] = {false, false}; // 表示各表通道0、1的eos状态 + atomic readySendEosCnt[2]; + atomic finishSendEosCnt[2]; + const double timeoutGetUniqueKeys = 30.0; // 如果超时仍未获取到数据将触发EOS + const double timeoutGetUniqueKeysEmpty = 1.0; // 如果超时仍未获取到数据将打印信息 void InitHotEmbTotCount(const EmbInfo& info, const RankInfo& rInfo); @@ -240,8 +249,6 @@ namespace MxRec { void Key2Offset(const EmbNameT& embName, KeysT& splitKey, int channel); - void Key2OffsetDynamicExpansion(const EmbNameT& embName, KeysT& splitKey, int channel); - unique_ptr GetBatchData(int channel, int commId) const; void BuildRestoreVec(const unique_ptr& batch, const vector& blockOffset, @@ -262,7 +269,10 @@ namespace MxRec { void HandleHotAndSendCount(const unique_ptr &batch, UniqueInfo& uniqueInfoOut, KeySendInfo& keySendInfo, vector& sc, vector& splitSize); - void PushResult(unique_ptr& batch, unique_ptr> tensors, KeysT& lookupKeys); + void PushResultHBM(unique_ptr& batch, unique_ptr> tensors); + + void PushResultDDR(unique_ptr& batch, unique_ptr> tensors, + std::vector& uniqueKeys, std::vector& restoreVecSec); void PushGlobalUniqueTensors(const unique_ptr>& tensors, KeysT& lookupKeys, int channel); @@ -290,6 +300,15 @@ namespace MxRec { } string DumpSplitKeys(vector>& splitKeys) const; + + bool IsGetInfoVecEos(int batch, const string& embName, int channel); + + bool IsGetUniqueKeysEos(const EmbBaseInfo& info, std::chrono::system_clock::time_point& startTime, + map& lookUpSwapInAddrsPushId); + + void SendEosTensor(const std::string& embName, int channel, bool sendAllChannel); + + std::thread StartEosMonitorThread(const EmbBaseInfo& info, bool& cancelMonitor); }; #define KEY_PROCESS_INSTANCE Singleton::GetInstance() diff --git a/src/core/l3_storage/cache_manager.cpp b/src/core/l3_storage/cache_manager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7ea68e1439103cbbbce32e5d082570087b155a42 --- /dev/null +++ b/src/core/l3_storage/cache_manager.cpp @@ -0,0 +1,365 @@ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and 
+==============================================================================*/ + +#include "cache_manager.h" + +#include +#include +#include +#include + +#include "utils/common.h" +#include "utils/time_cost.h" + +using namespace MxRec; + +void CacheManager::Init(ock::ctr::EmbCacheManagerPtr embCachePtr, vector& mgmtEmbInfo, + shared_ptr level3Storage) +{ + LOG_INFO("CacheManager Init method begin"); + if (level3Storage == nullptr) { + throw runtime_error("level3Storage is nullptr"); + } + + this->embCache = std::move(embCachePtr); + for (auto& emb : mgmtEmbInfo) { + EmbBaseInfo baseInfo {emb.ssdVocabSize, emb.ssdDataPath, false, emb.extEmbeddingSize}; + embBaseInfos.emplace(emb.name, baseInfo); + preProcessMapper[emb.name].Initialize(emb.name, emb.hostVocabSize, emb.ssdVocabSize); + } + this->l3Storage = level3Storage; + this->l3Storage->Start(); + LOG_INFO("CacheManager Init method end"); +} + +bool CacheManager::IsKeyInL3Storage(const string& embTableName, emb_cache_key_t key) +{ + return l3Storage->IsKeyExist(embTableName, key); +} + +/// 淘汰三级存储中Emb信息 +/// \param embTableName emb表名 +/// \param keys 淘汰key列表 +void CacheManager::EvictL3StorageEmbedding(const string& embTableName, const vector& keys) +{ + if (keys.empty()) { + return; + } + + int keyStep = preProcessStep; + unordered_map& l3StorageMap = preProcessMapper[embTableName].excludeDDRKeyCountMap; + LFUCache& ddrLfu = preProcessMapper[embTableName].lfuCache; + std::vector l3StorageKeysToBeDeleted; + // 1 删除缓存中记录的key的次数 + for (auto &key: keys) { + auto it = l3StorageMap.find(key); + if (it != l3StorageMap.end()) { + l3StorageMap.erase(it); + l3StorageKeysToBeDeleted.emplace_back(key); + } else { + ddrLfu.Pop(key); + } + } + + l3StorageEvictThreads.emplace_back([=]() mutable { + // 2 删除L3Storage中保存的Emb数据 + std::unique_lock lk(evictWaitMut); + evictWaitCond.wait(lk, [keyStep, this] { + return embeddingTaskStep == keyStep; + }); + l3Storage->DeleteEmbeddings(embTableName, l3StorageKeysToBeDeleted); + }); +} + +/// 放入key,新增/更新(次数+1)次数 +/// \param embTableName emb表名 +/// \param key key +/// \param type 记录类型 +void CacheManager::PutKey(const string& embTableName, const emb_key_t& key, RecordType type) +{ + if (type == RecordType::DDR) { + ddrKeyFreqMap[embTableName].Put(key); + return; + } + auto& hashMap = excludeDDRKeyCountMap[embTableName]; + const auto& it = hashMap.find(key); + freq_num_t count = it == hashMap.end() ? 
1 : it->second + 1; + hashMap[key] = count; +} + +void CacheManager::CreateL3StorageTableIfNotExist(const std::string& embTableName) +{ + if (embBaseInfos[embTableName].isExist) { + return; + } + if (!l3Storage->IsTableExist(embTableName)) { + l3Storage->CreateTable(embTableName, embBaseInfos[embTableName].savePath, + embBaseInfos[embTableName].maxTableSize); + embBaseInfos[embTableName].isExist = true; + LOG_INFO("create l3Storage table end, embTableName:{}", embTableName); + return; + } + // 续训场景:embBaseInfos 没有保存,不会初始化;L3Storage表会初始化,此时表已存在 + embBaseInfos[embTableName].isExist = true; + LOG_INFO("l3Storage table is exist, embTableName:{}", embTableName); +} + +CacheManager::~CacheManager() +{ + for (auto& t : l3StorageEvictThreads) { + t.join(); + } + l3Storage->Stop(); + ddrKeyFreqMap.clear(); + excludeDDRKeyCountMap.clear(); +} + +/// 加载数据到CacheManager +/// \param ddrFreqInitMap ddr内key频次数据 +/// \param excludeDdrFreqInitMap 非DDR key频次数据 +/// \param step 加载L3Storage传入步数 +void CacheManager::Load(const std::vector &mgmtEmbInfo, int step, + map>& trainKeySet) +{ + // 加载L3Storage数据 +#ifndef GTEST + for (auto& it : embBaseInfos) { + string embTableName = it.first; + EmbBaseInfo& embBase = it.second; + l3Storage->Load(embTableName, embBase.savePath, embBase.maxTableSize, step); + } + auto tableKeysVec = l3Storage->ExportTableKey(); + for (auto &it: tableKeysVec) { + auto &embTableName = it.first; + auto &keys = it.second; + for (auto key: keys) { + preProcessMapper[embTableName].excludeDDRKeyCountMap[key] = 1; + trainKeySet[embTableName].insert(key); + } + } + for (const auto &embInfo: mgmtEmbInfo) { + const std::string &tableName = embInfo.name; + std::vector buffer; + int rc = embCache->Serialize(tableName, buffer); + if (rc != 0) { + throw std::runtime_error("Serialize failed!"); + } + uint64_t memSize = sizeof(uint64_t) + embInfo.extEmbeddingSize * sizeof(float); + for (uint64_t i = 0; i < buffer.size(); i += memSize) { + uint64_t key = *reinterpret_cast(&buffer[i]); + preProcessMapper[tableName].lfuCache.Put(key); + } + } +#endif +} + +void CacheManager::Save(int step) +{ +#ifndef GTEST + l3Storage->Save(step); +#endif +} + +int64_t CacheManager::GetTableUsage(const string& tableName) +{ + if (l3Storage == nullptr) { + throw runtime_error("L3Storage not init"); + } + return l3Storage->GetTableUsage(tableName); +} + +void CacheManager::ProcessSwapOutKeys(const string& tableName, const vector& swapOutKeys, + HBMSwapOutInfo& info) +{ + auto& swapOutDDRKeys = info.swapOutDDRKeys; + auto& swapOutDDRAddrOffs = info.swapOutDDRAddrOffs; + auto& swapOutL3StorageKeys = info.swapOutL3StorageKeys; + auto& swapOutL3StorageAddrOffs = info.swapOutL3StorageAddrOffs; + + // 处理一下没见过的key,看是更新到DDR还是L3Storage中 + auto& keyMapper = preProcessMapper[tableName]; + size_t availableDDRSize = keyMapper.DDRAvailableSize(); + for (size_t i = 0; i < swapOutKeys.size(); ++i) { + emb_cache_key_t key = swapOutKeys[i]; + if (keyMapper.IsDDRKeyExist(key)) { + keyMapper.lfuCache.Put(key); + swapOutDDRKeys.push_back(key); + swapOutDDRAddrOffs.push_back(i); + } else if (keyMapper.IsL3StorageKeyExist(key)) { + keyMapper.excludeDDRKeyCountMap[key]++; + swapOutL3StorageKeys.push_back(key); + swapOutL3StorageAddrOffs.push_back(i); + } else if (availableDDRSize > 0) { + keyMapper.InsertDDRKey(key); + swapOutDDRKeys.push_back(key); + swapOutDDRAddrOffs.push_back(i); + availableDDRSize--; + } else { + keyMapper.InsertL3StorageKey(key); + swapOutL3StorageKeys.push_back(key); + swapOutL3StorageAddrOffs.push_back(i); + } + } +} + 
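+ // A short worked example of the placement policy above, assuming availableDDRSize == 1:
+ // swapOutKeys = {k7, k8, k9}; k7 is already tracked in DDR, k8 already lives in L3Storage, k9 is new.
+ // - k7 -> swapOutDDRKeys (its LFU count is bumped via lfuCache.Put)
+ // - k8 -> swapOutL3StorageKeys (its excludeDDRKeyCountMap count is bumped)
+ // - k9 -> swapOutDDRKeys, consuming the last free DDR slot; a further unseen key would
+ //   fall through to the final branch and be placed in swapOutL3StorageKeys instead.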
+void CacheManager::ProcessSwapInKeys(const string& tableName, const vector<emb_cache_key_t>& swapInKeys,
+    vector<emb_cache_key_t>& DDRToL3StorageKeys,
+    vector<emb_cache_key_t>& L3StorageToDDRKeys)
+{
+    auto& keyMapper = preProcessMapper[tableName];
+    size_t externalDDRSize = 0;
+    std::vector<emb_cache_key_t> firstSeenKeys;
+    for (emb_cache_key_t key : swapInKeys) {
+        if (keyMapper.IsDDRKeyExist(key)) {
+            continue;
+        }
+        externalDDRSize++;
+        if (keyMapper.IsL3StorageKeyExist(key)) {
+            L3StorageToDDRKeys.push_back(key);
+        } else {
+            firstSeenKeys.push_back(key);
+        }
+    }
+
+    auto ddrAvailableSize = keyMapper.DDRAvailableSize();
+    if (externalDDRSize > ddrAvailableSize) { // a DDR ---> L3Storage transfer is required
+        size_t transNum = externalDDRSize - ddrAvailableSize;
+
+        if (transNum > keyMapper.L3StorageAvailableSize()) {
+            throw invalid_argument(
+                "L3Storage table size too small, key quantity exceed while transferring DDR data to L3Storage");
+        }
+        // DDR ---> L3Storage
+        keyMapper.GetAndDeleteLeastFreqDDRKey2L3Storage(transNum, swapInKeys, DDRToL3StorageKeys);
+    }
+
+    // L3Storage ---> DDR
+    for (uint64_t key : L3StorageToDDRKeys) {
+        keyMapper.InsertDDRKey(key);
+        keyMapper.RemoveL3StorageKey(key);
+    }
+    for (uint64_t key : firstSeenKeys) {
+        keyMapper.InsertDDRKey(key);
+    }
+    preProcessStep++;
+}
+
+void CacheManager::UpdateL3StorageEmb(string tableName, float* embPtr, uint32_t extEmbeddingSize,
+    vector<emb_cache_key_t>& keys, const vector<uint64_t>& swapOutL3StorageOffs)
+{
+    vector<float*> embeddingsAddr(keys.size());
+    for (uint64_t i = 0; i < swapOutL3StorageOffs.size(); i++) {
+        embeddingsAddr[i] = embPtr + swapOutL3StorageOffs[i] * extEmbeddingSize;
+    }
+    l3Storage->InsertEmbeddingsByAddr(tableName, keys, embeddingsAddr, extEmbeddingSize);
+}
+
+void CacheManager::TransferDDR2L3Storage(string tableName, uint32_t extEmbeddingSize, vector<emb_cache_key_t>& keys,
+    vector<float*>& addrs)
+{
+    CreateL3StorageTableIfNotExist(tableName);
+    l3Storage->InsertEmbeddingsByAddr(tableName, keys, addrs, extEmbeddingSize);
+    for (auto& addr : addrs) {
+        free(addr);
+        addr = nullptr;
+    }
+}
+
+void CacheManager::FetchL3StorageEmb2DDR(string tableName, uint32_t extEmbeddingSize, vector<emb_cache_key_t>& keys,
+    const vector<float*>& addrs)
+{
+    auto embeddings = l3Storage->FetchEmbeddings(tableName, keys);
+    for (uint64_t i = 0; i < embeddings.size(); i++) {
+        int rc = memcpy_s(addrs[i], extEmbeddingSize * sizeof(float), embeddings[i].data(),
+            extEmbeddingSize * sizeof(float));
+        if (rc != 0) {
+            throw runtime_error("memcpy_s failed, rc: " + to_string(rc));
+        }
+    }
+    l3Storage->DeleteEmbeddings(tableName, keys);
+
+    embeddingTaskStep++;
+    evictWaitCond.notify_all();
+}
+
+void CacheManager::BackUpTrainStatus()
+{
+    ddrKeyFreqMapBackUp = ddrKeyFreqMap;
+    excludeDDRKeyCountMapBackUp = excludeDDRKeyCountMap;
+}
+
+void CacheManager::RecoverTrainStatus()
+{
+    for (const auto& pair : excludeDDRKeyCountMapBackUp) {
+        auto tableName = pair.first;
+
+        std::vector<uint64_t> ssdKeysBeforeEval;
+        std::vector<uint64_t> ssdKeysAfterEval;
+        std::vector<uint64_t> swapInKeys;
+        std::vector<uint64_t> swapOutKeys;
+
+        for (const auto& keyMap : pair.second) {
+            ssdKeysBeforeEval.push_back(keyMap.first);
+        }
+        for (const auto& keyMap : excludeDDRKeyCountMap[tableName]) {
+            ssdKeysAfterEval.push_back(keyMap.first);
+        }
+
+        GetSwapInAndSwapOutKeys(ssdKeysBeforeEval, ssdKeysAfterEval, swapInKeys, swapOutKeys);
+
+        // ddr <-> ssd
+        // ddr: look up addresses; ssd: insert embeddings; ddr: remove embeddings
+        vector<float*> swapInKeysAddr;
+        int rc = embCache->EmbeddingLookupAddrs(tableName, swapInKeys, swapInKeysAddr);
+        if (rc != 0) {
+            throw runtime_error("EmbeddingLookupAddrs failed! 
error code: " + std::to_string(rc)); + } + auto extEmbeddingSize = embBaseInfos[tableName].extEmbeddingSize; + l3Storage->InsertEmbeddingsByAddr(tableName, swapInKeys, swapInKeysAddr, extEmbeddingSize); + rc = embCache->EmbeddingRemove(tableName, swapInKeys); + if (rc != 0) { + throw runtime_error("EmbeddingRemove failed! error code: " + std::to_string(rc)); + } + + // ssd->fetch embedding, ddr->EmbeddingUpdate, ssd->delete embedding + auto swapOutEmbeddings = l3Storage->FetchEmbeddings(tableName, swapOutKeys); + vector swapOutFlattenEmbeddings; + for (auto& emb : swapOutEmbeddings) { + swapOutFlattenEmbeddings.insert(swapOutFlattenEmbeddings.cend(), emb.cbegin(), emb.cend()); + } + rc = embCache->EmbeddingUpdate(tableName, swapOutKeys, swapOutFlattenEmbeddings.data()); + l3Storage->DeleteEmbeddings(tableName, swapOutKeys); + } + + ddrKeyFreqMap = ddrKeyFreqMapBackUp; + excludeDDRKeyCountMap = excludeDDRKeyCountMapBackUp; +} + +void CacheManager::GetSwapInAndSwapOutKeys(vector& ssdKeysBeforeEval, + vector& ssdKeysAfterEval, + vector& swapInKeys, vector& swapOutKeys) +{ + std::sort(ssdKeysBeforeEval.begin(), ssdKeysBeforeEval.end()); + std::sort(ssdKeysAfterEval.begin(), ssdKeysAfterEval.end()); + vector intersectionKeys; + std::set_intersection(ssdKeysBeforeEval.begin(), ssdKeysBeforeEval.end(), ssdKeysAfterEval.begin(), + ssdKeysAfterEval.end(), std::back_inserter(intersectionKeys)); + + std::set_difference(ssdKeysBeforeEval.begin(), ssdKeysBeforeEval.end(), intersectionKeys.begin(), + intersectionKeys.end(), std::back_inserter(swapInKeys)); + std::set_difference(ssdKeysAfterEval.begin(), ssdKeysAfterEval.end(), intersectionKeys.begin(), + intersectionKeys.end(), std::back_inserter(swapOutKeys)); +} + diff --git a/src/core/ssd_cache/cache_manager.h b/src/core/l3_storage/cache_manager.h similarity index 38% rename from src/core/ssd_cache/cache_manager.h rename to src/core/l3_storage/cache_manager.h index e750626d65f8c2cb3aa7ede52a882639babad7b9..34e7f0c24a758856fb4c5e4ed16a36fc6bb3bdb0 100644 --- a/src/core/ssd_cache/cache_manager.h +++ b/src/core/l3_storage/cache_manager.h @@ -23,10 +23,12 @@ See the License for the specific language governing permissions and #include #include "hd_transfer/hd_transfer.h" -#include "host_emb/host_emb.h" #include "lfu_cache.h" #include "ssd_engine/ssd_engine.h" #include "utils/common.h" +#include "preprocess_mapper.h" +#include "ock_ctr_common/include/factory.h" +#include "l3_storage.h" namespace MxRec { @@ -36,14 +38,19 @@ namespace MxRec { size_t devVocabSize; size_t& maxOffset; absl::flat_hash_map& keyOffsetMap; - std::vector& evictDevPos; // 记录HBM内被淘汰的key - std::vector& evictHostPos; // 记录Host内淘汰列表 + }; + + struct HBMSwapOutInfo { + vector swapOutDDRKeys; + vector swapOutDDRAddrOffs; + vector swapOutL3StorageKeys; + vector swapOutL3StorageAddrOffs; }; enum class TransferRet { TRANSFER_OK = 0, // 转移成功或无需处理 TRANSFER_ERROR, - SSD_SPACE_NOT_ENOUGH, + L3STORAGE_SPACE_NOT_ENOUGH, DDR_SPACE_NOT_ENOUGH, }; @@ -67,89 +74,78 @@ namespace MxRec { ~CacheManager(); - void Init(HostEmb* hostEmbPtr, vector& mgmtEmbInfo); - - void Load(unordered_map>& ddrFreqInitMap, - unordered_map>& excludeDdrFreqInitMap, - int step, int rankSize, int rankId); + void Init(ock::ctr::EmbCacheManagerPtr embCachePtr, vector& mgmtEmbInfo, + shared_ptr level3Storage); - void SaveSSDEngine(int step); + void Load(const std::vector& mgmtEmbInfo, int step, + map>& trainKeySet); - // 转换DDR和SSD数据 - TransferRet TransferDDREmbWithSSD(TableInfo& table, - const vector& originalKeys, int channelId); 
+ void Save(int step); - /* HBM与DDR换入换出时刷新频次信息 */ - void RefreshFreqInfoCommon(const string& embTableName, vector& keys, - TransferType type); + bool IsKeyInL3Storage(const string& embTableName, emb_cache_key_t key); - bool IsKeyInSSD(const string& embTableName, emb_key_t key); - - void EvictSSDEmbedding(const string& embTableName, vector& keys); + void EvictL3StorageEmbedding(const string& embTableName, const vector& keys); void PutKey(const string& embTableName, const emb_key_t& key, RecordType type); - // DDR内每个表中emb数据频次缓存;map - unordered_map ddrKeyFreqMap; - // 每张表中非DDR内key的出现次数 - unordered_map> excludeDDRKeyCountMap; - - int64_t GetTableEmbeddingSize(const string& tableName); - - private: - struct EmbBaseInfo { - uint64_t maxTableSize; - vector savePath; - bool isExist; - }; + void ProcessSwapOutKeys(const string& tableName, const vector& swapOutKeys, + HBMSwapOutInfo& info); - void GetDDREmbInfo(vector& keys, - TableInfo& table, - vector& ddrTransferPos, vector>& ddrEmbData) const; + void ProcessSwapInKeys(const string& tableName, const vector& swapInKeys, + vector& DDRToL3StorageKeys, + vector& L3StorageToDDRKeys); - void UpdateDDREmbInfo(const std::string& embTableName, - vector& ddrTransferPos, - vector>& ssdEmbData) const; + void UpdateL3StorageEmb(string tableName, float* embPtr, uint32_t extEmbeddingSize,\ + vector& keys, + const vector& swapOutL3StorageAddrOffs); - void RefreshRelateInfoWithDDR2SSD(TableInfo& table, - vector& ddrSwapOutKeys, vector& ddrSwapOutCounts); + void TransferDDR2L3Storage(string tableName, uint32_t extEmbeddingSize, vector& keys, + vector& addrs); - void RefreshRelateInfoWithSSD2DDR(TableInfo& table, - vector& externalSSDKeys, vector& ddrTransferPos); + void FetchL3StorageEmb2DDR(string tableName, uint32_t extEmbeddingSize, vector& keys, + const vector& addrs); - void GetSSDKeys(const std::string& embTableName, vector& externalKeys, - vector& externalSSDKeys); + int64_t GetTableUsage(const string& tableName); - TransferRet TransferDDREmb2SSD(TableInfo& table, - int64_t ddrSwapOutSize, const vector& keys, - vector& ddrTransferPos); + void BackUpTrainStatus(); - TransferRet TransferSSDEmb2DDR(TableInfo& table, - vector& externalSSDKeys, vector& ddrTransferPos, - vector>& ssdEmbData); + void RecoverTrainStatus(); - void CreateSSDTableIfNotExist(const std::string& embTableName); + void GetSwapInAndSwapOutKeys(vector& ssdKeysBeforeEval, + vector& ssdKeysAfterEval, + vector& swapInKeys, vector& swapOutKeys); - void RestoreLeastFreqInfo(const std::string& embTableName, vector& ddrSwapOutKeys, - vector& ddrSwapOutCounts); + // DDR内每个表中emb数据频次缓存;map + unordered_map ddrKeyFreqMap; + unordered_map ddrKeyFreqMapBackUp; + // 每张表中非DDR内key的出现次数 + unordered_map> excludeDDRKeyCountMap; + unordered_map> excludeDDRKeyCountMapBackUp; - static void HandleDDRTransferPos(vector& ddrTransferPos, vector& externalSSDKeys, - TableInfo& table); + // 每一个table对应一个PreProcessMapper,预先推演HBM->DDR的情况 + std::unordered_map preProcessMapper; - inline void GetExternalKeys(const absl::flat_hash_map &keyOffsetMap, - vector& externalKeys, - vector& internalKeys, const vector& keys) const; + int preProcessStep = 0; + int embeddingTaskStep = 0; + std::mutex evictWaitMut; + std::condition_variable evictWaitCond; - void AddDebugAndTraceLog(size_t batchKeySize, vector& externalKeys, - vector& externalSSDKeys) const; + private: + struct EmbBaseInfo { + uint64_t maxTableSize; + vector savePath; + bool isExist; + int extEmbeddingSize; + }; - void HandleRepeatAndInvalidKey(const vector& originalKeys, 
vector& keys) const; + void CreateL3StorageTableIfNotExist(const std::string& embTableName); unordered_map embBaseInfos; GTEST_PRIVATE: - shared_ptr ssdEngine = std::make_shared(); - HostEmb* hostEmbs {}; + shared_ptr l3Storage; + vector l3StorageEvictThreads; + ock::ctr::EmbCacheManagerPtr embCache {}; }; } diff --git a/src/core/l3_storage/l3_storage.cpp b/src/core/l3_storage/l3_storage.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cc26d8a4dc3e693d5f58b8629b83f17d8c62b1d4 --- /dev/null +++ b/src/core/l3_storage/l3_storage.cpp @@ -0,0 +1,72 @@ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + +#include "l3_storage.h" + +using MxRec::L3Storage; +using MxRec::emb_cache_key_t; +using std::vector; +using std::string; + +L3Storage::L3Storage() {} + +L3Storage::~L3Storage() {} + +bool L3Storage::IsTableExist(const string& tableName) +{ + return false; +} + +bool L3Storage::IsKeyExist(const string& tableName, emb_cache_key_t key) +{ + return false; +} + +void L3Storage::CreateTable(const string& tableName, vector savePaths, uint64_t maxTableSize) {} + +int64_t L3Storage::GetTableAvailableSpace(const string& tableName) +{ + return 0; +} + +void L3Storage::InsertEmbeddingsByAddr(const string& tableName, vector& keys, + vector& embeddingsAddr, uint64_t extEmbeddingSize) +{ +} + +void L3Storage::DeleteEmbeddings(const string& tableName, vector& keys) {} + +vector> L3Storage::FetchEmbeddings(const string& tableName, vector& keys) +{ + return vector>(); +} + +void L3Storage::Save(int step) {} + +void L3Storage::Load(const string& tableName, vector savePaths, uint64_t maxTableSize, int step) {} + +void L3Storage::Start() {} + +void L3Storage::Stop() {} + +int64_t L3Storage::GetTableUsage(const string& tableName) +{ + return 0; +} + +vector>> L3Storage::ExportTableKey() +{ + return vector>>(); +} diff --git a/src/core/emb_table/emb_table.h b/src/core/l3_storage/l3_storage.h similarity index 30% rename from src/core/emb_table/emb_table.h rename to src/core/l3_storage/l3_storage.h index 2d30818c1c977ec2c323ac9ec034b2dff68f90b9..5f7270c128d502a7578f614a1b819f4e1de0912c 100644 --- a/src/core/emb_table/emb_table.h +++ b/src/core/l3_storage/l3_storage.h @@ -13,81 +13,49 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MX_REC_EMB_TABLE_H -#define MX_REC_EMB_TABLE_H +#ifndef MX_REC_L3_STORAGE_H +#define MX_REC_L3_STORAGE_H -#include -#include -#include -#include +#include +#include #include "utils/common.h" namespace MxRec { - using namespace std; +class L3Storage { +public: + L3Storage(); + virtual ~L3Storage(); - class EmbTable { - public: - EmbTable() = default; + virtual bool IsTableExist(const std::string& tableName); - void Init(const EmbInfo& eInfo, const RankInfo& rInfo, int initSeed = 0); + virtual bool IsKeyExist(const std::string& tableName, emb_cache_key_t key); - ~EmbTable(); + virtual void CreateTable(const std::string& tableName, std::vector savePaths, uint64_t maxTableSize); - // 从embeddingList获取获取一个可用的emb地址 - int64_t GetEmbAddress(); + virtual int64_t GetTableAvailableSpace(const std::string& tableName); - // 打印emb表使用情况 - void PrintStatus() const; + virtual void InsertEmbeddingsByAddr(const std::string& tableName, std::vector& keys, + std::vector& embeddingsAddr, uint64_t extEmbeddingSize); - int64_t GetTableSize() const; + virtual void DeleteEmbeddings(const std::string& tableName, std::vector& keys); - int64_t GetTableCapacity() const; + virtual std::vector> FetchEmbeddings(const std::string& tableName, + std::vector& keys); - EmbTable(const EmbTable&) = delete; + virtual void Save(int step); - EmbTable(EmbTable&&) = delete; + virtual void Load(const std::string& tableName, std::vector savePaths, uint64_t maxTableSize, + int step); - EmbTable& operator=(const EmbTable&) = delete; + virtual void Start(); - EmbTable& operator=(EmbTable&&) = delete; + virtual void Stop(); - void ExecuteAclMemcpy(void* newBlock, vector devEmb) const; + virtual int64_t GetTableUsage(const std::string& tableName); - GTEST_PRIVATE: - constexpr static int BLOCK_EMB_COUNT = 100000; - constexpr static int INIT_BLOCK_COUNT = 5; - constexpr static int TEST_EMB_SIZE = 12; - EmbInfo embInfo; - RankInfo rankInfo; - size_t blockSize = 1; - int embSize = 1; - size_t totalCapacity = 1; - size_t usedCapacity = 0; - int seed = 0; - // embedding地址的列表 - list embeddingList; - // 内存块列表 - vector memoryList; - - void RandomInit(void* newBlock); - - // embSize由embInfo得出 - void SplitMemoryBlock(void* newBlock); - - // 内部类,抛出内存不足异常 - class OutOfMemoryError : public runtime_error { - public: - OutOfMemoryError() : runtime_error("Out of memory!") {} - }; - - // 内部类,抛出acl异常 - class AclError : public runtime_error { - public: - AclError() : runtime_error("Acl failed!") {} - }; - }; -} - -#endif // MX_REC_EMB_TABLE_MANAGER_H \ No newline at end of file + virtual std::vector>> ExportTableKey(); +}; +} // namespace MxRec +#endif // MX_REC_L3_STORAGE_H \ No newline at end of file diff --git a/src/core/ssd_cache/lfu_cache.cpp b/src/core/l3_storage/lfu_cache.cpp similarity index 79% rename from src/core/ssd_cache/lfu_cache.cpp rename to src/core/l3_storage/lfu_cache.cpp index c204e336bd848ddaaac984a2f410df6664bbad12..c2d38bd2ced468a9cca2e6d6656fc6a0e940fcd0 100644 --- a/src/core/ssd_cache/lfu_cache.cpp +++ b/src/core/l3_storage/lfu_cache.cpp @@ -25,7 +25,7 @@ using namespace MxRec; /// 仅获取当前key的频次,不增加频次;key不存在时返回-1 /// \param key key /// \return key的频次 -freq_num_t LFUCache::Get(emb_key_t key) +freq_num_t LFUCache::Get(emb_cache_key_t key) { auto it = keyTable.find(key); if (it == keyTable.end()) { return -1; } @@ -37,13 +37,16 @@ freq_num_t LFUCache::Get(emb_key_t key) /// \param keys 要返回的最低频次key不能在该列表内 /// \param ddrSwapOutKeys 记录最低频次key /// \param 
ddrSwapOutCounts 记录最低频次key对应次数 -void LFUCache::GetAndDeleteLeastFreqKeyInfo(int64_t num, const vector& keys, - vector& ddrSwapOutKeys, vector& ddrSwapOutCounts) +void LFUCache::GetAndDeleteLeastFreqKeyInfo(uint64_t num, const vector& keys, + vector& ddrSwapOutKeys, + vector& ddrSwapOutCounts) { freq_num_t tempMinFreq = minFreq; - unordered_set retainedKeySet(keys.begin(), keys.end()); - int64_t counter = 0; + unordered_set retainedKeySet(keys.begin(), keys.end()); + uint64_t counter = 0; const size_t freqSize = freqTable.size(); + LOG_DEBUG("table:{}, num:{}, freqTable.size:{}, keys.size:{}, ddrSwapOutKeys.size:{}, ddrSwapOutCounts.size:{}", + name, num, freqTable.size(), keys.size(), ddrSwapOutKeys.size(), ddrSwapOutCounts.size()); // 遍历freqTable<次数,keyList>时,次数可能不连续,要实际使用了1个keyList后才自增,手动增加计数器 for (size_t i = 0; i < freqSize;) { auto nodesIter = freqTable.find(tempMinFreq); @@ -53,7 +56,7 @@ void LFUCache::GetAndDeleteLeastFreqKeyInfo(int64_t num, const vector } auto nodeIt = freqTable[tempMinFreq].begin(); while (nodeIt != freqTable[tempMinFreq].end() && !freqTable[tempMinFreq].empty() && counter < num) { - emb_key_t currentKey = nodeIt->key; + emb_cache_key_t currentKey = nodeIt->key; if (retainedKeySet.find(currentKey) != retainedKeySet.end()) { // 当前key在指定的集合中,不满足 nodeIt++; @@ -80,7 +83,7 @@ void LFUCache::GetAndDeleteLeastFreqKeyInfo(int64_t num, const vector /// 放入key,新增/更新(次数+1)次数 /// \param key key -void LFUCache::Put(emb_key_t key) +void LFUCache::Put(emb_cache_key_t key) { auto it = keyTable.find(key); if (it == keyTable.end()) { @@ -94,8 +97,10 @@ void LFUCache::Put(emb_key_t key) freqTable[freq].erase(node); if (freqTable[freq].empty()) { freqTable.erase(freq); + if (minFreq == freq) { + minFreq += 1; + } } - if (minFreq == freq) { minFreq += 1; } freqTable[freq + 1].emplace_front(key, freq + 1); keyTable[key] = freqTable[freq + 1].begin(); } @@ -103,7 +108,7 @@ void LFUCache::Put(emb_key_t key) /// 直接放入指定次数;用于初始化场景 /// \param key key /// \param freq 频次 -void LFUCache::PutWithInit(emb_key_t key, freq_num_t freq) +void LFUCache::PutWithInit(emb_cache_key_t key, freq_num_t freq) { if (keyTable.find(key) != keyTable.end()) { // 一般初始化时,key应该不存在已经被插入的情况;此处替换就的key频次信息 @@ -120,7 +125,7 @@ void LFUCache::PutWithInit(emb_key_t key, freq_num_t freq) } /// 删除指定key -bool LFUCache::Pop(emb_key_t key) +bool LFUCache::Pop(emb_cache_key_t key) { auto it = keyTable.find(key); if (it == keyTable.end()) { @@ -139,15 +144,23 @@ bool LFUCache::Pop(emb_key_t key) /// 获取所有的key和次数信息 /// \return 频次数据map -std::unordered_map LFUCache::GetFreqTable() +std::unordered_map LFUCache::GetFreqTable() { - unordered_map freqMap(keyTable.size()); + unordered_map freqMap(keyTable.size()); for (const auto& it :keyTable) { freqMap[it.first] = it.second->freq; } return freqMap; } +LFUCache::LFUCache(const string& cacheName) +{ + name = cacheName; + minFreq = 0; + keyTable.clear(); + freqTable.clear(); +} + LFUCache::LFUCache() { minFreq = 0; diff --git a/src/core/ssd_cache/lfu_cache.h b/src/core/l3_storage/lfu_cache.h similarity index 67% rename from src/core/ssd_cache/lfu_cache.h rename to src/core/l3_storage/lfu_cache.h index 247e490ecefe3799537912a0c23d1439db4626cd..94fde5399ed68d84fac28b26bcbb8c83e9d35d14 100644 --- a/src/core/ssd_cache/lfu_cache.h +++ b/src/core/l3_storage/lfu_cache.h @@ -31,10 +31,10 @@ namespace MxRec { // 记录key和次数信息 struct LFUCacheNode { - emb_key_t key; + emb_cache_key_t key; freq_num_t freq; - LFUCacheNode(emb_key_t key, freq_num_t freq) : key(key), freq(freq) + 
LFUCacheNode(emb_cache_key_t key, freq_num_t freq) : key(key), freq(freq) {} }; @@ -42,25 +42,29 @@ namespace MxRec { public: LFUCache(); - freq_num_t Get(emb_key_t key); + explicit LFUCache(const string& cacheName); - void GetAndDeleteLeastFreqKeyInfo(int64_t num, const vector& keys, - vector& ddrSwapOutKeys, + freq_num_t Get(emb_cache_key_t key); + + void GetAndDeleteLeastFreqKeyInfo(uint64_t num, const vector& keys, + vector& ddrSwapOutKeys, vector& ddrSwapOutCounts); - void Put(emb_key_t key); + void Put(emb_cache_key_t key); - bool Pop(emb_key_t key); + bool Pop(emb_cache_key_t key); - void PutWithInit(emb_key_t key, freq_num_t freq); + void PutWithInit(emb_cache_key_t key, freq_num_t freq); - std::unordered_map GetFreqTable(); + std::unordered_map GetFreqTable(); // 最小频次 freq_num_t minFreq = 0; // 次数, 该次数对应的key列表(key, freq) std::unordered_map> freqTable; // key, key所属node在freqTable的节点列表中的存储位置地址 - std::unordered_map::iterator> keyTable; + std::unordered_map::iterator> keyTable; + private: + string name; }; } diff --git a/src/core/l3_storage/preprocess_mapper.h b/src/core/l3_storage/preprocess_mapper.h new file mode 100644 index 0000000000000000000000000000000000000000..0fc8e4d8ad9177212a1535e5985d83ab627e89e8 --- /dev/null +++ b/src/core/l3_storage/preprocess_mapper.h @@ -0,0 +1,116 @@ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + +#ifndef MXREC_DDR_PREPROCESS_MAPPER_H +#define MXREC_DDR_PREPROCESS_MAPPER_H + +#include +#include "lfu_cache.h" + +namespace MxRec { + /* + * 专供keys处理的线程使用,每一个emb_local_table就有一个DDRPreProcessMapper + * MapperBase中的桶存储k-v对,在这里value统一赋值为0 + */ + class PreProcessMapper { + public: + void Initialize(const string& embName, size_t ddrVocabSize, size_t l3StorageVocabSize) + { + tableName = embName; + lfuCache = LFUCache(embName); + ddrAvailableSize = ddrVocabSize; + l3StorageAvailableSize = l3StorageVocabSize; + } + + bool IsDDRKeyExist(uint64_t key) + { + return lfuCache.keyTable.find(key) != lfuCache.keyTable.end(); + } + + bool IsL3StorageKeyExist(uint64_t key) + { + return excludeDDRKeyCountMap.find(key) != excludeDDRKeyCountMap.end(); + } + + bool InsertDDRKey(uint64_t key) + { + if (IsDDRKeyExist(key)) { + throw std::invalid_argument("InsertDDRKey failed! key already exist"); + } + + freq_num_t freq = excludeDDRKeyCountMap[key] + 1; + lfuCache.PutWithInit(key, freq); + return true; + } + + bool InsertL3StorageKey(uint64_t key) + { + if (IsL3StorageKeyExist(key)) { + throw std::invalid_argument("InsertL3StorageKey failed! key already exist"); + } + + excludeDDRKeyCountMap[key] = 1; + return true; + } + + bool RemoveL3StorageKey(uint64_t key) + { + if (!IsL3StorageKeyExist(key)) { + throw std::invalid_argument("RemoveKey failed! 
key not exist"); + } + excludeDDRKeyCountMap.erase(key); + return true; + } + + size_t DDRAvailableSize() + { + if (ddrAvailableSize < lfuCache.keyTable.size()) { + throw std::invalid_argument("ddrAvailableSize < existKeys.size()"); + } + return ddrAvailableSize - lfuCache.keyTable.size(); + } + + size_t L3StorageAvailableSize() + { + if (l3StorageAvailableSize < excludeDDRKeyCountMap.size()) { + throw std::invalid_argument("l3StorageAvailableSize < existKeys.size()"); + } + return l3StorageAvailableSize - excludeDDRKeyCountMap.size(); + } + + void GetAndDeleteLeastFreqDDRKey2L3Storage(uint64_t transNum, const std::vector& keys, + std::vector& DDRSwapOutKeys) + { + LOG_DEBUG("start GetAndDeleteLeastFreqDDRKey2L3Storage, table:{}", tableName); + std::vector DDRSwapOutCounts; + lfuCache.GetAndDeleteLeastFreqKeyInfo(transNum, keys, DDRSwapOutKeys, DDRSwapOutCounts); + for (uint64_t i = 0; i < DDRSwapOutKeys.size(); i++) { + excludeDDRKeyCountMap[DDRSwapOutKeys[i]] = DDRSwapOutCounts[i]; + } + if (DDRSwapOutCounts.size() != transNum) { + throw std::invalid_argument( + "GetAndDeleteLeastFreqDDRKey2L3Storage failed! DDRSwapOutCounts.size()!=transNum"); + } + } + + string tableName; + uint64_t ddrAvailableSize = 0; + uint64_t l3StorageAvailableSize = 0; + LFUCache lfuCache; + std::unordered_map excludeDDRKeyCountMap; + }; +} + +#endif // MXREC_DDR_PREPROCESS_MAPPER_H diff --git a/src/core/ock_ctr_common/include/embedding_cache.h b/src/core/ock_ctr_common/include/embedding_cache.h new file mode 100644 index 0000000000000000000000000000000000000000..ce807f160e17b2ef5917e7386c3f88ae2ce3f95d --- /dev/null +++ b/src/core/ock_ctr_common/include/embedding_cache.h @@ -0,0 +1,341 @@ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.s +See the License for the specific language governing permissions and + limitations under the License. 
+==============================================================================*/ + +#ifndef EMBEDDING_CACHE_H +#define EMBEDDING_CACHE_H + +#include +#include +#include +#include + +namespace EmbCache { +using KeyOffsetPair = std::pair, std::vector>; + +class Initializer { +public: + Initializer() = default; + virtual ~Initializer() = default; + + /* * + * 生成随机数 + * @Param emb embedding的首地址 + */ + virtual void GenerateData(float* emb, int embSize) = 0; + uint32_t start = 0; // 起始位置 + uint32_t len = 0; // 初始化的长度 + float initParam = 1.0; // 初始化器生成的初始值均需要乘以initParam +}; + +enum class InitializerType { + INVALID, + CONSTANT, + TRUNCATED_NORMAL, + RANDOM_NORMAL +}; + +struct ConstantInitializerInfo { + ConstantInitializerInfo() = default; + + ConstantInitializerInfo(float constantValue, float initK); + + float constantValue = 0; // 常量值 + float initK = 1.0; // 初始化出来的值需乘以initK +}; + +struct NormalInitializerInfo { + NormalInitializerInfo() = default; + + NormalInitializerInfo(float mean, float stddev, uint32_t seed, float initK); + + float mean = 0; // 平均值 + float stddev = 0; // 标准差 + uint32_t seed = 0; // 随机数种子 + float initK = 1.0; // 初始化出来的值需乘以initK +}; + +class ConstantInitializer : public Initializer { +public: + ConstantInitializer() = default; + + ConstantInitializer(uint32_t start, uint32_t len, float value, float initK); + + ~ConstantInitializer() override = default; + + void GenerateData(float* emb, int embSize) override; + + uint32_t start = 0; // 起始位置 + uint32_t len = 0; // 初始化的长度 + float constantValue = 0; // 常量值 +}; + +class RandomNormalInitializer : public Initializer { +public: + RandomNormalInitializer() = default; + RandomNormalInitializer(uint32_t start, uint32_t len, NormalInitializerInfo& initInfo); + + ~RandomNormalInitializer() override = default; + + void GenerateData(float* emb, int embSize) override; + + uint32_t start = 0; // 起始位置 + uint32_t len = 0; // 初始化的长度 + float mean = 0; // 平均值 + float stddev = 0; // 标准差 + uint32_t seed = 0; // 随机数种子 + + std::default_random_engine generator; // 随机数生成器 + std::normal_distribution distribution; // 正态分布 +}; + +class TruncatedNormalInitializer : public Initializer { +public: + TruncatedNormalInitializer() = default; + + TruncatedNormalInitializer(uint32_t start, uint32_t len, NormalInitializerInfo& initInfo); + + ~TruncatedNormalInitializer() override = default; + + void GenerateData(float* emb, int embSize) override; + + int boundNum = 2; + + uint32_t start = 0; // 起始位置 + uint32_t len = 0; // 初始化的长度 + float mean = 0; // 平均值 + float stddev = 0; // 标准差 + uint32_t seed = 0; // 随机数种子 + + std::default_random_engine generator; // 随机数生成器 + std::normal_distribution distribution; + float minBound = 0; // 下界 + float maxBound = 0; // 上界 +}; + +struct InitializerInfo { + InitializerInfo() = default; + + InitializerInfo(std::string& name, uint32_t start, uint32_t len, ConstantInitializerInfo constantInitializerInfo); + + InitializerInfo(std::string& name, uint32_t start, uint32_t len, NormalInitializerInfo normalInitializerInfo); + + std::string name = ""; // 初始化器的名称 + uint32_t start = 0; // 初始化开始的位置 + uint32_t len = 0; // 待初始化的长度 + InitializerType initializerType = InitializerType::INVALID; + + ConstantInitializerInfo constantInitializerInfo; + NormalInitializerInfo normalInitializerInfo; + + std::shared_ptr initializer; +}; + +struct EmbCacheInfo { + EmbCacheInfo(std::string tableName, uint32_t vocabSize, uint32_t embeddingSize, uint32_t extEmbeddingSize, + uint32_t maxCacheSize) + : tableName(tableName), + vocabSize(vocabSize), + 
embeddingSize(embeddingSize), + extEmbeddingSize(extEmbeddingSize), + maxCacheSize(maxCacheSize) + { + } + std::string tableName = ""; + uint32_t vocabSize = 0; // host侧的容量(能存多少条embedding) + uint32_t embeddingSize = 0; + uint32_t extEmbeddingSize = 0; // 包含embedding和优化器信息的embedding长度 + uint32_t maxCacheSize = 0; // device侧的容量(能存多少条embedding) +}; + +class EmbCacheManager { +public: + virtual ~EmbCacheManager() = default; + + /* * + * 对当前embInfo对应的table在cache_manager中进行table初始化 + * @Param EmbCacheInfo: embedding cache的初始化信息 + * @Param std::vector 初始化器的信息 + * @Param uint64_t prefillBufferSize emb内存池恒定可用大小 + * @Param uint32_t refillThreadNum emb内存池自动填充线程数 + * @Return errorCode + */ + virtual int CreateCacheForTable(const EmbCacheInfo& embCacheInfo, + const std::vector& initializerInfos, int64_t invalidKey = -1, + uint64_t prefillBufferSize = 500000, uint32_t refillThreadNum = 1) = 0; + + /* * + * 查找当前keys对应的offsets并将本不存在与offsetMapper中的keys插入到offsetMapper中并得到其偏移值offsets, + * 并且当offsetMapper可存放空间不足时,释放可swapOut的keys,获取当前需要被换入换出的keys和offsets的pair + * @Param tableName: 表名 + * @Param keys: 当前batch所有unique的keys + * @Param swapInKoPair: 输出参数,需要换入的Key-offset pair + * @Param swapOutKoPair: 输出参数,需要换出的Key-offset pair + * @Return errorCode + */ + virtual int GetSwapPairsAndKey2Offset(std::string tableName, std::vector& keys, + KeyOffsetPair& swapInKoPair, KeyOffsetPair& swapOutKoPair) = 0; + + /* * + * 查询Embedding + * @Param tableName: 表名 + * @Param keys: 待查询的keys + * @Param embAddr: 申请出来存放embedding的空间首地址 + * @Param threadNum: 线程数 + * @Return errorCode + */ + virtual int EmbeddingLookup(std::string tableName, const std::vector& keys, float* embAddr, + uint32_t threadNum = 4) = 0; + + /* * + * 查询Embedding的地址 + * @Param tableName: 表名 + * @Param keys: 待查询的keys + * @Param addrs: keys对应的申请出来存放embedding的空间首地址 + * @Param threadNum: 线程数 + * @Return errorCode + */ + virtual int EmbeddingLookupAddrs(std::string tableName, const std::vector& keys, + std::vector& addrs, uint32_t threadNum = 4) = 0; + + /* * + * 查询Embedding并且在查询完成之后删除embedding对应的key。如果多线程使用,严格保证传入的key线程间不会重复(unique + * key),否则可能出现未定义结果 + * @Param tableName: 表名 + * @Param keys: 待查询的keys + * @Param embAddr: 申请出来存放embedding的空间首地址 + * @Param threadNum: 线程数 + * @Return errorCode + */ + virtual int EmbeddingLookupAndRemove(std::string tableName, const std::vector& keys, float* embAddr, + uint32_t threadNum = 4) = 0; + + /* * + * 更新Embedding + * @Param tableName: 表名 + * @Param keys: 待更新的keys,用于查询出每个key在DDR上存放的地址 + * @Param embAddr: 待更新到DDR上的embedding的首地址 + * @Param threadNum: 线程数 + * @Return errorCode + */ + virtual int EmbeddingUpdate(std::string tableName, const std::vector& keys, float* embAddr, + uint32_t threadNum = 4) = 0; + + /* * + * 在EmbLocalTable中移除keys,并将存储其embedding的内存位置记为可复用 + * @Param tableName: 表名 + * @Param keys: 待移除的keys + * @Return errorCode + */ + virtual int EmbeddingRemove(std::string tableName, const std::vector& keys, uint32_t threadNum = 4) = 0; + + /* * + * 将需要被淘汰的keys从offsetMapper的记录中移除,同时也在EmbLocalTable中移除,并将存储其embedding的内存位置记为可复用 + * @Param tableName: 表名 + * @Param keys: 待淘汰的keys + * @Return errorCode + */ + virtual int RemoveEmbsByKeys(std::string tableName, const std::vector& keys) = 0; + + /* * + * 获取所有table names + * @Param allTableNames: 输出参数,用于存放所有的table names + * @Return errorCode + */ + virtual int GetEmbTableNames(std::vector& allTableNames) = 0; + + /* * + * 获取以values为增序排列的当前记录在offsetMapper中所有的keys和values的pairs + * @Param tableName: 表名 + * koVec: 输出参数 + * @Return errorCode + */ + virtual int 
ExportDeviceKeyOffsetPairs(std::string tableName, + std::vector>& koVec) = 0; + + /* * + * 获取当前table的序列化信息 + * @Param tableName: 要序列化的表 + * @Param buffer: 输出参数,存储序列化之后的信息 + * @Return errorCode + */ + virtual int Serialize(std::string tableName, std::vector& buffer) = 0; + + /* * + * 将当前table的序列化信息进行反序列化 + * @Param tableName: 要反序列化的表 + * @Param buffer: 输入参数,将buffer中的内容进行反序列化 + * @Return errorCode + */ + virtual int Deserialize(std::string tableName, const std::vector& buffer) = 0; + + /* * + * 析构所有embCache,释放内存 + */ + virtual void Destroy() = 0; + + /* * + * 查询表的使用量 + * @Param tableName: 要查询的表 + * @Return 当前表的使用量 + */ + virtual uint32_t GetUsage(const std::string& tableName) = 0; + + /* * + * 获取当前host侧所存储的所有keys及其对应的embeddings和优化器参数 + * @Param tableName: 需要获取信息的table名字 + * @Param keys: 输入参数,输入空vector,获取的存储的所有keys会赋到该vector中 + * @Param embeddings: 输入参数,输入空vector,获取的存储的所有embeddings会赋到该vector中 + * @Param optimizerSlots: 输入参数,输入空vector,获取的存储的所有optimizerSlots会赋到该vector中 + * @Return errorCode + */ + virtual int GetEmbTableInfos(std::string tableName, std::vector& keys, + std::vector>& embeddings, + std::vector>& optimizerSlots) = 0; + + /* * + * 将所需存储的keys及其对应的embeddings和优化器参数传入,来装载LocalEmbeddingTable + * @Param tableName: 需要加载信息的table名字 + * @Param keys: 输入参数,需要加载的所有keys + * @Param embeddings: 输入参数,需要加载的所有embeddings + * @Param optimizerSlots: 输入参数,需要加载的所有optimizerSlots + * @Return errorCode + */ + virtual int LoadEmbTableInfos(std::string tableName, const std::vector& keys, + const std::vector>& embeddings, + const std::vector>& optimizerSlots) = 0; + + /* * + * When switch the channel to eval, backup the current table's offsetMapper object. + * @Param tableName: embedding table name + * @Return errorCode + */ + virtual int BackUpTrainStatus(const std::string& tableName) = 0; + + /* * + * When switch the eval channel back to train, Recover the current table's offsetMapper object to the backup state. + * @Param tableName: embedding table name + * @Return errorCode + */ + virtual int RecoverTrainStatus(const std::string& tableName) = 0; + + /* * + * Reset the offsetMapper object to revert to its initialized state after loading. 
+ * @Return errorCode + */ + virtual int ResetOffsetMappers() = 0; +}; +} // namespace EmbCache + +#endif // EMBEDDING_CACHE_H diff --git a/src/core/ock_ctr_common/include/factory.h b/src/core/ock_ctr_common/include/factory.h index 44a2fce01d3951adcfdcc9cd6489d01c5b89938e..ce701abe422cae9e08d4b643a6e3d17dae8074f3 100644 --- a/src/core/ock_ctr_common/include/factory.h +++ b/src/core/ock_ctr_common/include/factory.h @@ -17,16 +17,17 @@ See the License for the specific language governing permissions and #define UNIQUE_OCK_CTR_COMMON_H #include -#include #include -#include "unique.h" +#include +#include "embedding_cache.h" +#include "unique.h" #ifdef __cplusplus extern "C" { #endif -using ExternalLog = void (*)(int level, const char *msg); +using ExternalLog = void (*)(int level, const char* msg); #ifdef __cplusplus } @@ -40,26 +41,28 @@ class Factory; using FactoryPtr = std::shared_ptr; using UniquePtr = std::shared_ptr; +using EmbCacheManagerPtr = std::shared_ptr; class Factory { public: virtual ~Factory() = default; - virtual int CreateUnique(UniquePtr &out) = 0; + virtual int CreateUnique(UniquePtr& out) = 0; + virtual int CreateEmbCacheManager(EmbCacheManagerPtr& out) = 0; virtual int SetExternalLogFuncInner(ExternalLog logFunc) = 0; public: - static int Create(FactoryPtr &out) + static int Create(FactoryPtr& out) { int result = 0; uintptr_t factory = 0; /* dynamic load function */ - if ((result = OckCtrCommonDef::CreatFactory(&factory)) == 0) { - out.reset(reinterpret_cast(factory)); + if ((result = OckCtrCommonDef::CreateFactory(&factory)) == 0) { + out.reset(reinterpret_cast(factory)); } return result; } }; -} -} +} // namespace ctr +} // namespace ock -#endif // UNIQUE_OCK_CTR_COMMON_H +#endif // UNIQUE_OCK_CTR_COMMON_H diff --git a/src/core/ock_ctr_common/include/ock_ctr_common_def.h b/src/core/ock_ctr_common/include/ock_ctr_common_def.h index e8b3f0b5e5e51bdf70580ab149bf902e26b53e9a..537d7a394aaf34382284b43a945b38ad50b2af09 100644 --- a/src/core/ock_ctr_common/include/ock_ctr_common_def.h +++ b/src/core/ock_ctr_common/include/ock_ctr_common_def.h @@ -20,15 +20,15 @@ See the License for the specific language governing permissions and #include #include -using CTR_CREATE_FACTORY_FUNCTION = int (*)(uintptr_t *); +using CTR_CREATE_FACTORY_FUNCTION = int (*)(uintptr_t*); namespace ock { namespace ctr { class OckCtrCommonDef { public: - static int CreatFactory(uintptr_t *factory) + static int CreateFactory(uintptr_t* factory) { - static void *handle = nullptr; + static void* handle = nullptr; static std::mutex m; std::unique_lock lock(m); if (handle != nullptr) { @@ -38,8 +38,8 @@ public: handle = dlopen(LIBRARY_NAME, RTLD_NOW); if (handle == nullptr) { - std::cout << "Failed to call dlopen to load library '" << LIBRARY_NAME << "', error " << dlerror() << - std::endl; + std::cout << "Failed to call dlopen to load library '" << LIBRARY_NAME << "', error " << dlerror() + << std::endl; return -1; } @@ -55,9 +55,9 @@ public: } private: - constexpr static const char *LIBRARY_NAME = "lib_ock_ctr_common.so"; + constexpr static const char* LIBRARY_NAME = "lib_ock_ctr_common.so"; }; -} -} +} // namespace ctr +} // namespace ock -#endif // OCK_OCK_CTR_COMMON_DEF_H +#endif // OCK_OCK_CTR_COMMON_DEF_H diff --git a/src/core/ock_ctr_common/include/unique.h b/src/core/ock_ctr_common/include/unique.h index cb8960e7cadaad2e69305e266732453e0582d92b..5d11fe66aef0537b2237c2865e57fe8ef39607e2 100644 --- a/src/core/ock_ctr_common/include/unique.h +++ b/src/core/ock_ctr_common/include/unique.h @@ -59,6 +59,7 @@ 
using UniqueConf = struct UniqueConfCTR { uint32_t maxThreadNum = 8; // 最大工作线程数 int64_t maxIdVal = 0; // 最大id值 bool trace = false; // 是否开启性能检测,需要配合外部日志输出 + bool performance = false; // 是否开启增强接口,增强接口shardingNum必须是2的幂次方,默认用取模分桶 } __attribute__((packed)); using UniqueIn = struct UniqueInCTR { diff --git a/src/core/ssd_cache/cache_manager.cpp b/src/core/ssd_cache/cache_manager.cpp deleted file mode 100644 index 36be19d91388084a36b11dac3f9a7ede8f50a410..0000000000000000000000000000000000000000 --- a/src/core/ssd_cache/cache_manager.cpp +++ /dev/null @@ -1,527 +0,0 @@ -/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and - limitations under the License. -==============================================================================*/ - -#include "cache_manager.h" - -#include -#include -#include - -#include "utils/common.h" -#include "utils/time_cost.h" - -using namespace MxRec; - -inline void CacheManager::GetExternalKeys(const absl::flat_hash_map &keyOffsetMap, - vector &externalKeys, vector &internalKeys, - const vector &keys) const -{ - for (const emb_key_t key : keys) { - if (keyOffsetMap.find(key) == keyOffsetMap.end()) { - externalKeys.emplace_back(key); - } else { - internalKeys.emplace_back(key); - } - } -} - -void CacheManager::AddDebugAndTraceLog(size_t batchKeySize, vector &externalKeys, - vector &externalSSDKeys) const -{ - LOG_DEBUG("TransferDDREmbWithSSD: batchKeySize:{}, externalKeys size:{}, externalSSDKeys size:{}", - batchKeySize, externalKeys.size(), externalSSDKeys.size()); - LOG_TRACE("TransferDDREmbWithSSD: externalKeys:{}, externalSSDKeys:{}", - VectorToString(externalKeys), VectorToString(externalSSDKeys)); -} - -/// 去重和过滤无效key -/// \param originalKeys 原有keys -/// \param keys 处理后的keys -void CacheManager::HandleRepeatAndInvalidKey(const vector& originalKeys, vector& keys) const -{ - // 去重并保持原key的顺序 结果可测试 - unordered_set keySet; - for (auto& key : originalKeys) { - if (key == INVALID_KEY_VALUE) { - continue; - } - if (keySet.find(key) == keySet.end()) { - keySet.emplace(key); - keys.emplace_back(key); - } - } -} - -/// DDR与SSD数据转移,使DDR内剩余空间能放置当前批次key -/// \param embTableName emb表名 -/// \param embHashMap emb表 -/// \param originalKeys 当前批次key -/// \param channelId 通道id -/// \return 转移结果枚举 -TransferRet CacheManager::TransferDDREmbWithSSD(TableInfo& table, - const vector& originalKeys, int channelId) -{ - vector keys; // 去重和删除无效key - HandleRepeatAndInvalidKey(originalKeys, keys); - // 区分HBM+DDR内key,和HBM+DDR外的key(新key或保存在SSD中的key) - vector externalKeys; - vector internalKeys; - GetExternalKeys(table.keyOffsetMap, externalKeys, internalKeys, keys); - if (externalKeys.empty()) { return TransferRet::TRANSFER_OK; } - - // 判断剩余内存空间是否足够; 可用内存空间计算:HBM+DDR-已占用; 若是训练,再加DDR已淘汰; - // SSD仅与DDR交互,不考虑HBM淘汰位置;由于maxOffset比实际使用大1,所以虽然从0开始也不用再减1 - size_t ddrAvailableSize = table.devVocabSize + table.hostVocabSize - table.maxOffset; - if (channelId == TRAIN_CHANNEL_ID) { - ddrAvailableSize += table.evictHostPos.size(); - } - LOG_DEBUG("TransferDDREmbWithSSD, table:{}, maxOffset:{}, 
evictHostPos size:{}, ddrAvailableSize:{}", - table.name, table.maxOffset, table.evictHostPos.size(), ddrAvailableSize); - CreateSSDTableIfNotExist(table.name); - - // 调用ssdEngine查询当前批次key中保存在SSD中的key - vector externalSSDKeys; - GetSSDKeys(table.name, externalKeys, externalSSDKeys); - // 后续判断maxOffset是否超出范围时,maxOffset=devVocabSize+hostVocabSize时可用,此处包含等于 - bool isDDRSpaceEnough = ddrAvailableSize >= externalKeys.size(); - bool ddrSpaceEnoughOrEval = channelId != TRAIN_CHANNEL_ID || isDDRSpaceEnough; - if (ddrSpaceEnoughOrEval && externalSSDKeys.empty()) { - // 部分场景后续不用处理,在此处返回 - return TransferRet::TRANSFER_OK; - } - - AddDebugAndTraceLog(keys.size(), externalKeys, externalSSDKeys); - /* - * 前面 externalSSDKeys = 0 ,评估场景的 ddr空间可用、不可用已返回; 训练的可用已返回; - * 剩下的情况如下: - * 评估: - * externalSSDKeys > 0, 可用 & 不可用操作一样; - * 可选:Ddr->ssd, 腾出 externalSSDKeys 大小空间; - * Ssd->ddr, 需要移动 externalSSDKeys ; - * externalSSDKeys = 0 --已返回 - * 训练: - * externalSSDKeys > 0 - * 可用: - * 可选:Ddr->ssd, 腾出 externalSSDKeys 大小空间; - * Ssd->ddr, 需要移动 externalSSDKeys ; - * 不可用: - * 必选:Ddr->ssd, 腾出 externalKeys 大小空间; - * 需要计算ssd剩余空间:externalKeys - externalSSDKeys - * (注: 当前策略均转移externalKeys) - * Ssd->ddr, 需要移动 externalSSDKeys ; - * externalSSDKeys = 0 - * 可用: --已返回 - * 不可用: - * Ddr->ssd, 腾出 externalKeys 大小的空间; - * 需要计算ssd剩余空间: externalKeys - * 因cache每次只转移DDR最小空间,上述可选动作也需执行,避免SSD移入DDR时空间不足 - */ - // 训练场景检查SSD剩余空间 评估不考虑新key - if (channelId == TRAIN_CHANNEL_ID) { - size_t needSSDSize = externalKeys.size() - externalSSDKeys.size() - ddrAvailableSize; - const int64_t ssdAvailableSize = ssdEngine->GetTableAvailableSpace(table.name); - if (int64_t(needSSDSize) > ssdAvailableSize) { - LOG_ERROR("TransferDDREmbWithSSD: ssd available space is not enough to transfer DDR emb data. " - "needSSDSize:{}, ssdAvailableSize:{}", needSSDSize, ssdAvailableSize); - return TransferRet::SSD_SPACE_NOT_ENOUGH; - } - } - - // 从SSD获取emb数据并从SSD删除; 避免DDR->SSD时空间不够 - vector> ssdEmbData; - if (!externalSSDKeys.empty()) { - ssdEmbData = ssdEngine->FetchEmbeddings(table.name, externalSSDKeys); - ssdEngine->DeleteEmbeddings(table.name, externalSSDKeys); - } - - // 从ddr转移到ssd的key个数 - size_t ddrSwapOutSizeTmp = ddrSpaceEnoughOrEval ? 
externalSSDKeys.size() : externalKeys.size(); - auto ddrSwapOutSize = static_cast(ddrSwapOutSizeTmp - ddrAvailableSize); - LOG_DEBUG("TransferDDREmbWithSSD: ddrSwapOutSize:{}", ddrSwapOutSize); - - /* - * 转移DDR中数据到SSD - */ - // 记录要从DDR转移到SSD的key对应的offset(相对值,需减去devVocabSize) - vector ddrTransferPos; - TransferRet ddr2SsdRet = TransferDDREmb2SSD(table, ddrSwapOutSize, internalKeys, ddrTransferPos); - if (ddr2SsdRet == TransferRet::DDR_SPACE_NOT_ENOUGH) { - ssdEngine->InsertEmbeddings(table.name, externalSSDKeys, ssdEmbData); - return ddr2SsdRet; - } - - HandleDDRTransferPos(ddrTransferPos, externalSSDKeys, table); - - /* - * 转移SSD中保存的当前批次key的emb数据到DDR - */ - return TransferSSDEmb2DDR(table, externalSSDKeys, ddrTransferPos, ssdEmbData); -} - -/// SSD数据转移到DDR中后刷新映射和频次信息 -/// \param embTableName emb表名 -/// \param embHashMap emb hash表 -/// \param externalSSDKeys 存储在SSD中的key列表 -/// \param ddrTransferPos -void CacheManager::RefreshRelateInfoWithSSD2DDR(TableInfo& table, - vector& externalSSDKeys, vector& ddrTransferPos) -{ - for (size_t i = 0; i < externalSSDKeys.size(); ++i) { - // 映射关系 ddrTransferPos是在ddrEmbHash中的位置,记录映射时需加上devVocabSize - auto& key = externalSSDKeys[i]; - table.keyOffsetMap[key] = ddrTransferPos[i] + table.devVocabSize; - // 频次 - ddrKeyFreqMap[table.name].PutWithInit(key, excludeDDRKeyCountMap[table.name][key]); - excludeDDRKeyCountMap[table.name].erase(key); - } -} - -void CacheManager::GetDDREmbInfo(vector& keys, TableInfo& table, - vector& ddrTransferPos, vector>& ddrEmbData) const -{ - // 根据offset 获取对应Emb数据 - for (auto& key : keys) { - auto koCast = static_cast(table.keyOffsetMap[key]); - ddrTransferPos.emplace_back(koCast - table.devVocabSize); - } - - LOG_TRACE("DDR keys:{}", VectorToString(keys)); - LOG_TRACE("DDR key positions:{}", VectorToString(ddrTransferPos)); - - ddrEmbData.resize(keys.size()); - const auto& emb = hostEmbs->GetEmb(table.name); -#pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) shared(ddrTransferPos, emb, ddrEmbData) - for (size_t i = 0; i < ddrTransferPos.size(); ++i) { - auto& missingKeyPo = ddrTransferPos[i]; - const auto& src = emb.embData[missingKeyPo]; - ddrEmbData[i] = src; - } -} - -/// 使用ssdEmbData更新DDR内emb数据 -/// \param embTableName emb表名 -/// \param ddrTransferPos 需要更新的DDR内的offset -/// \param ssdEmbData SSD对应的emb数据 -void CacheManager::UpdateDDREmbInfo(const std::string& embTableName, - vector& ddrTransferPos, - vector>& ssdEmbData) const -{ - auto& emb = hostEmbs->GetEmb(embTableName); - auto& embData = emb.embData; -#pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) shared(ddrTransferPos, embData, ssdEmbData) - for (size_t i = 0; i < ddrTransferPos.size(); ++i) { - embData[ddrTransferPos[i]] = ssdEmbData[i]; - } -} - -/// DDR_2_SSD场景数据刷新: 仅刷新映射和频次,ddr转移出去的offset信息后续统一处理 -/// \param embTableName emb表名 -/// \param embHashMap emb map -/// \param ddrSwapOutKeys 从DDR中转移到SSD中key列表 -/// \param ddrSwapOutCounts 从DDR中转移到SSD中key频次数据 -void CacheManager::RefreshRelateInfoWithDDR2SSD(TableInfo& table, - vector& ddrSwapOutKeys, - vector& ddrSwapOutCounts) -{ - auto& excludeFreqMap = excludeDDRKeyCountMap[table.name]; - for (size_t i = 0; i < ddrSwapOutKeys.size(); ++i) { - auto& key = ddrSwapOutKeys[i]; - table.keyOffsetMap.erase(key); - excludeFreqMap[key] = ddrSwapOutCounts[i]; - } -} - -/// key从DDR移入、移出、HBM淘汰时刷新频次信息;仅刷新频次信息 -/// \param embTableName emb表名 -/// \param keys 操作的key集合 -/// \param type TransferType -void CacheManager::RefreshFreqInfoCommon(const string& embTableName, vector& keys, TransferType type) 
-{ - if (type == TransferType::DDR_2_HBM) { - for (auto& key : keys) { - // 频次数据记录到 excludeDDRKeyCountMap,并删除ddrKeyFreqMap中频次数据 - // 进入findOffset时记录的key次数 + ddr内key次数 - auto tmpCount = excludeDDRKeyCountMap[embTableName][key]; - excludeDDRKeyCountMap[embTableName][key] = ddrKeyFreqMap[embTableName].Get(key) + tmpCount; - ddrKeyFreqMap[embTableName].Pop(key); - } - } else if (type == TransferType::HBM_2_DDR) { - for (auto& key : keys) { - // excludeDDRKeyCountMap 中次数转移到 ddrKeyFreqMap, 并删除原记录 - ddrKeyFreqMap[embTableName].PutWithInit(key, excludeDDRKeyCountMap[embTableName][key]); - excludeDDRKeyCountMap[embTableName].erase(key); - } - } else if (type == TransferType::DDR_2_EVICT) { - for (auto& key : keys) { - ddrKeyFreqMap[embTableName].Pop(key); - } - } else { - // TransferType::HBM_2_EVICT - for (auto& key : keys) { - excludeDDRKeyCountMap[embTableName].erase(key); - } - } -} - -void CacheManager::Init(HostEmb* hostEmbPtr, vector& mgmtEmbInfo) -{ - this->hostEmbs = hostEmbPtr; - for (auto& emb : mgmtEmbInfo) { - EmbBaseInfo baseInfo {emb.ssdVocabSize, emb.ssdDataPath, false}; - embBaseInfos.emplace(emb.name, baseInfo); - ddrKeyFreqMap[emb.name]; - excludeDDRKeyCountMap[emb.name]; - } - ssdEngine->Start(); - LOG_INFO("CacheManager Init method end."); -} - -bool CacheManager::IsKeyInSSD(const string& embTableName, emb_key_t key) -{ - return ssdEngine->IsKeyExist(embTableName, key); -} - -/// 淘汰SSD中Emb信息 -/// \param embTableName emb表名 -/// \param keys 淘汰key列表 -void CacheManager::EvictSSDEmbedding(const string& embTableName, vector& keys) -{ - if (keys.empty()) { - return; - } - // 1 删除缓存中记录的key的次数 2 删除SSD中保存的Emb数据 - for (auto& key : keys) { - excludeDDRKeyCountMap[embTableName].erase(key); - } - ssdEngine->DeleteEmbeddings(embTableName, keys); -} - -/// 放入key,新增/更新(次数+1)次数 -/// \param embTableName emb表名 -/// \param key key -/// \param type 记录类型 -void CacheManager::PutKey(const string& embTableName, const emb_key_t& key, RecordType type) -{ - if (type == RecordType::DDR) { - ddrKeyFreqMap[embTableName].Put(key); - return; - } - auto& hashMap = excludeDDRKeyCountMap[embTableName]; - const auto& it = hashMap.find(key); - freq_num_t count = it == hashMap.end() ? 
1 : it->second + 1; - hashMap[key] = count; -} - -/// DDR->SSD与SSD->DDR的key个数可能不一致,手动补齐/截取 -/// \param ddrTransferPos DDR->SSD的offset列表(hostEmb表内的偏移值) -/// \param externalSSDKeys SSD->DDR的key列表 -/// \param embHashMap emb hash表 -void CacheManager::HandleDDRTransferPos(vector& ddrTransferPos, vector& externalSSDKeys, - TableInfo& table) -{ - if (ddrTransferPos.size() == externalSSDKeys.size()) { - return; - } - LOG_DEBUG("TransferDDREmbWithSSD: operate length is not equal, will padding or clipping, " - "ddrTransferPos size:{}, externalSSDKeys size:{}", - ddrTransferPos.size(), externalSSDKeys.size()); - // ddrTransferPos中是DDR内偏移位置,存入evictPos时,需加上devVocabSize;取出时需减去 - if (ddrTransferPos.size() > externalSSDKeys.size()) { - while (ddrTransferPos.size() > externalSSDKeys.size()) { - auto evictHostPos = ddrTransferPos.back() + table.devVocabSize; - table.evictHostPos.emplace_back(static_cast(evictHostPos)); - ddrTransferPos.pop_back(); - } - return; - } - // 补齐offset - while (ddrTransferPos.size() < externalSSDKeys.size() && !table.evictHostPos.empty()) { - ddrTransferPos.emplace_back(table.evictHostPos.back() - table.devVocabSize); - table.evictHostPos.pop_back(); - } - auto allSize = table.devVocabSize + table.hostVocabSize; - // 还不够继续使用maxOffset - while (ddrTransferPos.size() < externalSSDKeys.size() && table.maxOffset < allSize) { - auto nextPos = table.maxOffset++; - ddrTransferPos.emplace_back(nextPos - table.devVocabSize); - } - LOG_DEBUG("HandleDDRTransferPos: handle end, pos len:{}, keys len:{}", - ddrTransferPos.size(), externalSSDKeys.size()); -} - -void CacheManager::GetSSDKeys(const std::string& embTableName, vector& externalKeys, - vector& externalSSDKeys) -{ - for (auto& key : externalKeys) { - if (ssdEngine->IsKeyExist(embTableName, key)) { - externalSSDKeys.emplace_back(key); - } - } -} - -TransferRet CacheManager::TransferDDREmb2SSD(TableInfo& table, - int64_t ddrSwapOutSize, - const vector& keys, vector& ddrTransferPos) -{ - if (ddrSwapOutSize <= 0) { - // 此时不需要转移数据 - return TransferRet::TRANSFER_OK; - } - - TimeCost ddr2SsdTc; - LOG_DEBUG("TransferDDREmbWithSSD: get ddr least freq keys, table:{}, ddrSwapOutSize:{}", - table.name, ddrSwapOutSize); - // 获取DDR中指定数量的最低频次key,并获取相应emb数据,执行DDR换出到SSD - vector ddrSwapOutKeys; - vector ddrSwapOutCounts; - ddrKeyFreqMap[table.name].GetAndDeleteLeastFreqKeyInfo(ddrSwapOutSize, keys, ddrSwapOutKeys, ddrSwapOutCounts); - if (static_cast(ddrSwapOutKeys.size()) != ddrSwapOutSize) { - auto keyTableSize = ddrKeyFreqMap[table.name].keyTable.size(); - // 获取的最低频次key数量和预期不一致,DDR空间不足,不能放置当前批次数据 - LOG_ERROR("TransferDDREmbWithSSD, table:{}, vector length is not equal, ddrSwapOutKeys size:{}, " - "ddrSwapOutSize:{}, ddr lfu keyTable size:{}", - table.name, ddrSwapOutKeys.size(), ddrSwapOutSize, keyTableSize); - RestoreLeastFreqInfo(table.name, ddrSwapOutKeys, ddrSwapOutCounts); - return TransferRet::DDR_SPACE_NOT_ENOUGH; - } - LOG_DEBUG("TransferDDREmbWithSSD: get DDR embeddings and save to SSD, table:{}, size:{}", - table.name, ddrSwapOutKeys.size()); - // 获取DDR中emb数据 - vector> ddrEmbData; - GetDDREmbInfo(ddrSwapOutKeys, table, ddrTransferPos, ddrEmbData); - // 调用SSDEngine接口,将DDR Emb数据保存到SSD - ssdEngine->InsertEmbeddings(table.name, ddrSwapOutKeys, ddrEmbData); - - // 初始化DDR内被转移出去的位置 - hostEmbs->EvictInitEmb(table.name, ddrTransferPos); - - // 更新记录的DDR中key频次信息 - RefreshRelateInfoWithDDR2SSD(table, ddrSwapOutKeys, ddrSwapOutCounts); - LOG_DEBUG("TransferDDREmbWithSSD: table:{}, ddr2SsdTc TimeCost(ms):{}", table.name, ddr2SsdTc.ElapsedMS()); - return 
TransferRet::TRANSFER_OK; -} - -TransferRet CacheManager::TransferSSDEmb2DDR(TableInfo& table, - vector& externalSSDKeys, vector& ddrTransferPos, - vector>& ssdEmbData) -{ - if (externalSSDKeys.empty()) { - return TransferRet::TRANSFER_OK; - } - TimeCost ssd2DdrTc; - LOG_DEBUG("TransferDDREmbWithSSD: get SSD embeddings and save to DDR, size:{}", externalSSDKeys.size()); - if (ddrTransferPos.size() != externalSSDKeys.size() || externalSSDKeys.size() != ssdEmbData.size()) { - LOG_ERROR("TransferDDREmbWithSSD, vector length is not equal, ddrTransferPos len:{}, externalSSDKeys len:{}, " - "ssdEmbData len:{}", ddrTransferPos.size(), externalSSDKeys.size(), ssdEmbData.size()); - return TransferRet::TRANSFER_ERROR; - } - // 将SSD emb存储到DDR中 刷新频次信息 - UpdateDDREmbInfo(table.name, ddrTransferPos, ssdEmbData); - RefreshRelateInfoWithSSD2DDR(table, externalSSDKeys, ddrTransferPos); - LOG_DEBUG("TransferDDREmbWithSSD: ssd2DdrTc TimeCost(ms):{}", ssd2DdrTc.ElapsedMS()); - return TransferRet::TRANSFER_OK; -} - -void CacheManager::CreateSSDTableIfNotExist(const std::string& embTableName) -{ - if (embBaseInfos[embTableName].isExist) { - return; - } - if (!ssdEngine->IsTableExist(embTableName)) { - ssdEngine->CreateTable(embTableName, embBaseInfos[embTableName].savePath, - embBaseInfos[embTableName].maxTableSize); - embBaseInfos[embTableName].isExist = true; - LOG_INFO("create ssd table end, embTableName:" + embTableName); - return; - } - // 续训场景:embBaseInfos 没有保存,不会初始化;SSD表会初始化,此时表已存在 - embBaseInfos[embTableName].isExist = true; - LOG_INFO("ssd table is exist, embTableName:" + embTableName); -} - -void CacheManager::RestoreLeastFreqInfo(const std::string& embTableName, vector& ddrSwapOutKeys, - vector& ddrSwapOutCounts) -{ - auto& lfuCache = ddrKeyFreqMap[embTableName]; - for (size_t i = 0; i < ddrSwapOutKeys.size(); ++i) { - lfuCache.PutWithInit(ddrSwapOutKeys[i], ddrSwapOutCounts[i]); - } -} - -CacheManager::~CacheManager() -{ - hostEmbs = nullptr; - ssdEngine->Stop(); - ddrKeyFreqMap.clear(); - excludeDDRKeyCountMap.clear(); -} - -/// 加载数据到CacheManager -/// \param ddrFreqInitMap ddr内key频次数据 -/// \param excludeDdrFreqInitMap 非DDR key频次数据 -/// \param step 加载SSDEngine传入步数 -void CacheManager::Load(unordered_map>& ddrFreqInitMap, - unordered_map>& excludeDdrFreqInitMap, - int step, int rankSize, int rankId) -{ - if (rankSize <= 0) { - throw runtime_error("rank size must > 0"); - } - // 加载CacheManager数据 - for (auto& it : ddrFreqInitMap) { - auto& embTableName = it.first; - auto& freqMap = it.second; - for (auto& freqIt : freqMap) { - if (freqIt.first % rankSize != rankId) { - continue; - } - ddrKeyFreqMap[embTableName].PutWithInit(freqIt.first, freqIt.second); - } - } - for (auto& it : excludeDdrFreqInitMap) { - auto& embTableName = it.first; - auto& freqMap = it.second; - for (auto& freqIt : freqMap) { - if (freqIt.first % rankSize != rankId) { - continue; - } - excludeDDRKeyCountMap[embTableName].emplace(freqIt.first, freqIt.second); - } - } - // 加载SSDEngine数据 -#ifndef GTEST - for (auto& it : embBaseInfos) { - string embTableName = it.first; - EmbBaseInfo& embBase = it.second; - ssdEngine->Load(embTableName, embBase.savePath, embBase.maxTableSize, step); - } -#endif -} - -void CacheManager::SaveSSDEngine(int step) -{ -#ifndef GTEST - ssdEngine->Save(step); -#endif -} - -int64_t CacheManager::GetTableEmbeddingSize(const string& tableName) -{ - if (ssdEngine == nullptr) { - throw runtime_error("SSDEngine not init"); - } - return ssdEngine->GetTableEmbeddingSize(tableName); -} - diff --git 
diff --git a/src/core/ssd_engine/file.cpp b/src/core/ssd_engine/file.cpp
index 83395f362a4f1d93f9ee8234f5fa48bdbe10401f..8c7da24e418a83ca2e71ffa1e1ad1d71915116a8 100644
--- a/src/core/ssd_engine/file.cpp
+++ b/src/core/ssd_engine/file.cpp
@@ -24,7 +24,7 @@ using namespace MxRec;
 /// 创建新文件实例,包含元数据文件、数据文件
 /// \param fileID 文件ID
 /// \param fileDir 当前文件目录
-File::File(uint64_t fileID, string &fileDir) : fileID(fileID), fileDir(fileDir)
+File::File(uint64_t fileID, string& fileDir) : fileID(fileID), fileDir(fileDir)
 {
     LOG_DEBUG("start init file, fileID:{}", fileID);
@@ -75,7 +75,7 @@
 /// \param loadDir 加载文件的目录
 /// \param fileDir 当前文件目录
 /// \param step 加载的步数
-File::File(uint64_t fileID, string &fileDir, string &loadDir, int step) : fileID(fileID), fileDir(fileDir)
+File::File(uint64_t fileID, string& fileDir, string& loadDir, int step) : fileID(fileID), fileDir(fileDir)
 {
     LOG_DEBUG("start init file with load, fileID:{}", fileID);
@@ -141,13 +141,13 @@ File::~File()
     fs::remove(dataFilePath);
 }
 
-bool File::IsKeyExist(emb_key_t key)
+bool File::IsKeyExist(emb_cache_key_t key) const
 {
     auto it = keyToOffset.find(key);
     return !(it == keyToOffset.end());
 }
 
-void File::InsertEmbeddings(vector<emb_key_t> &keys, vector<vector<float>> &embeddings)
+void File::InsertEmbeddings(vector<emb_cache_key_t>& keys, vector<vector<float>>& embeddings)
 {
     if (keys.size() != embeddings.size()) {
         throw invalid_argument("keys' length not equal to embeddings' length");
@@ -178,10 +178,10 @@
     dataCnt += dLen;
 }
 
-vector<vector<float>> File::FetchEmbeddings(vector<emb_key_t> &keys)
+vector<vector<float>> File::FetchEmbeddings(vector<emb_cache_key_t>& keys)
 {
     vector<vector<float>> ret;
-    for (emb_key_t k: keys) {
+    for (emb_cache_key_t k: keys) {
         auto it = keyToOffset.find(k);
         if (it == keyToOffset.end()) {
             throw invalid_argument("key not exist");
@@ -208,7 +208,7 @@
     return ret;
 }
 
-void File::DeleteEmbedding(emb_key_t key)
+void File::DeleteEmbedding(emb_cache_key_t key)
 {
     if (!IsKeyExist(key)) {
         return;
@@ -217,7 +217,7 @@
     staleDataCnt += 1;
 }
 
-void File::Save(const string &saveDir, int step)
+void File::Save(const string& saveDir, int step)
 {
     LOG_DEBUG("start save file at step:{}, fileID:{}", step, fileID);
@@ -278,15 +278,15 @@ void File::Load()
 {
     // file already validate and open in instantiation
     LOG_DEBUG("start reading meta file, fileID:{}", fileID);
-    emb_key_t key;
+    emb_cache_key_t key;
     offset_t offset;
     do {
-        localFileMeta.read(reinterpret_cast<char*>(&key), keyDataLen);
+        localFileMeta.read(reinterpret_cast<char*>(&key), KEY_DATA_LEN);
         if (!localFileMeta.eof() && localFileMeta.fail()) {
             throw invalid_argument("file broken while reading key");
         }
-        localFileMeta.read(reinterpret_cast<char*>(&offset), offsetDataLen);
+        localFileMeta.read(reinterpret_cast<char*>(&offset), OFFSET_DATA_LEN);
         if (!localFileMeta.eof() && localFileMeta.fail()) {
             throw invalid_argument("file broken while reading offset");
         }
@@ -311,9 +311,9 @@
     LOG_DEBUG("end reading meta file, fileID:{}", fileID);
 }
 
-vector<emb_key_t> File::GetKeys()
+vector<emb_cache_key_t> File::GetKeys()
 {
-    vector<emb_key_t> ret;
+    vector<emb_cache_key_t> ret;
     for (auto item: keyToOffset) {
         ret.push_back(item.first);
     }
@@ -334,3 +334,40 @@ uint64_t File::GetStaleDataCnt() const
 {
     return staleDataCnt;
 }
+
+void File::InsertEmbeddingsByAddr(vector<emb_cache_key_t>& keys, vector<float*>& embeddingsAddr,
+    uint64_t extEmbeddingSize)
+{
+    if (keys.size() != embeddingsAddr.size()) {
+        throw invalid_argument("keys' length not equal to embeddings' length");
+    }
+
+    size_t dLen = keys.size();
+    for (size_t i = 0; i < dLen; ++i) {
+        if (embeddingsAddr[i] == nullptr) {
+            throw invalid_argument("Null pointer found in embeddingsAddr");
+        }
+    }
+
+    localFileData.seekp(lastWriteOffset);  // always set pointer to buffer end in case reading happened before
+
+    for (size_t i = 0; i < dLen; ++i) {
+        if (IsKeyExist(keys[i])) {
+            staleDataCnt++;
+        }
+        keyToOffset[keys[i]] = lastWriteOffset;
+
+        if (extEmbeddingSize > maxEmbSize) {
+            throw invalid_argument("embedding size too large");
+        }
+        localFileData.write(reinterpret_cast<const char*>(&extEmbeddingSize), sizeof(extEmbeddingSize));
+        localFileData.write(reinterpret_cast<const char*>(embeddingsAddr[i]), extEmbeddingSize * sizeof(float));
+
+        auto pos = localFileData.tellp();
+        if (pos == -1) {
+            throw runtime_error("can't get file position pointer, write data failed");
+        }
+        lastWriteOffset = offset_t(pos);
+    }
+    dataCnt += dLen;
+}
diff --git a/src/core/ssd_engine/file.h b/src/core/ssd_engine/file.h
index 949859db078c7ab35b85e94e0feada5ebfde3d8f..5789ab8b262deb337db3856026fe5a9aaa26a8b7 100644
--- a/src/core/ssd_engine/file.h
+++ b/src/core/ssd_engine/file.h
@@ -33,30 +33,31 @@ namespace MxRec {
     using offset_t = uint32_t;
 
     class File {
-        static const uint64_t keyDataLen = sizeof(emb_key_t);
-        static const uint64_t offsetDataLen = sizeof(offset_t);
+        static constexpr uint64_t KEY_DATA_LEN = sizeof(emb_cache_key_t);
+        static constexpr uint64_t OFFSET_DATA_LEN = sizeof(offset_t);
 
     public:
-        File(uint64_t fileID, string &fileDir);
+        File(uint64_t fileID, string& fileDir);
 
-        File(uint64_t fileID, string &fileDir, string &loadDir, int step); // initialize with loading specific step data
+        File(uint64_t fileID, string& fileDir, string& loadDir,
+            int step);  // initialize with loading specific step data
 
         File(const File&) = delete;
         File& operator=(const File&) = delete;
 
         ~File();
 
-        bool IsKeyExist(emb_key_t key);
+        bool IsKeyExist(emb_cache_key_t key) const;
 
-        void InsertEmbeddings(vector<emb_key_t> &keys, vector<vector<float>> &embeddings);
+        void InsertEmbeddings(vector<emb_cache_key_t>& keys, vector<vector<float>>& embeddings);
 
-        vector<vector<float>> FetchEmbeddings(vector<emb_key_t> &keys);
+        vector<vector<float>> FetchEmbeddings(vector<emb_cache_key_t>& keys);
 
-        void DeleteEmbedding(emb_key_t key);
+        void DeleteEmbedding(emb_cache_key_t key);
 
-        void Save(const string &saveDir, int step);
+        void Save(const string& saveDir, int step);
 
-        vector<emb_key_t> GetKeys();
+        vector<emb_cache_key_t> GetKeys();
 
         uint64_t GetDataCnt() const;
@@ -64,6 +65,9 @@
 
         uint64_t GetStaleDataCnt() const;
 
+        void InsertEmbeddingsByAddr(vector<emb_cache_key_t>& keys, vector<float*>& embeddingsAddr,
+            uint64_t extEmbeddingSize);
+
     private:
         uint64_t fileID;  // init by constructor
         string fileDir;   // init by constructor
@@ -77,7 +81,7 @@
         uint64_t dataCnt = 0;
         uint64_t staleDataCnt = 0;
 
-        unordered_map<emb_key_t, offset_t> keyToOffset{};  // offset_t >> maxDataNumInFile * embDataSize
+        unordered_map<emb_cache_key_t, offset_t> keyToOffset{};  // offset_t >> maxDataNumInFile * embDataSize
         offset_t lastWriteOffset = 0;
 
         void Load();
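The new `InsertEmbeddingsByAddr` writes each record as a length prefix (the `uint64_t` element count) followed by `extEmbeddingSize` floats, with `keyToOffset` remembering where each key's record starts. A reader sketch under exactly those layout assumptions follows; the function name and error strings are illustrative, not the repo's `FetchEmbeddings` implementation.

```cpp
#include <cstdint>
#include <fstream>
#include <stdexcept>
#include <vector>

// Read back one length-prefixed embedding record at `offset`, mirroring the
// write layout assumed above: [uint64_t count][count * float]. Illustrative only.
std::vector<float> ReadRecord(std::fstream& data, uint32_t offset)
{
    data.seekg(offset);
    uint64_t count = 0;
    data.read(reinterpret_cast<char*>(&count), sizeof(count));
    if (data.fail()) {
        throw std::runtime_error("file broken while reading record length");
    }
    std::vector<float> emb(count);
    data.read(reinterpret_cast<char*>(emb.data()), count * sizeof(float));
    if (data.fail()) {
        throw std::runtime_error("file broken while reading embedding payload");
    }
    return emb;
}
```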
diff --git a/src/core/ssd_engine/ssd_engine.cpp b/src/core/ssd_engine/ssd_engine.cpp
index 6570879250a737876a3661602faf9316a00f5624..3f0b3a1c0a049996041df125aeafb39a859b1616 100644
--- a/src/core/ssd_engine/ssd_engine.cpp
+++ b/src/core/ssd_engine/ssd_engine.cpp
@@ -27,7 +27,7 @@ bool SSDEngine::IsTableExist(const string &tableName)
     return !(it == tableMap.end());
 }
 
-bool SSDEngine::IsKeyExist(const string &tableName, emb_key_t key)
+bool SSDEngine::IsKeyExist(const string &tableName, emb_cache_key_t key)
 {
     if (!isRunning) {
         throw runtime_error("SSDEngine not running");
@@ -54,7 +54,8 @@ void SSDEngine::CreateTable(const string &tableName, vector<string> savePaths, u
     tableMap[tableName] = make_shared<Table>(tableName, savePaths, maxTableSize, compactThreshold);
 }
 
-void SSDEngine::InsertEmbeddings(const string &tableName, vector<emb_key_t> &keys, vector<vector<float>> &embeddings)
+void SSDEngine::InsertEmbeddings(const string& tableName, vector<emb_cache_key_t>& keys,
+    vector<vector<float>>& embeddings)
 {
     if (!isRunning) {
         throw runtime_error("SSDEngine not running");
@@ -71,7 +72,7 @@
     it->second->InsertEmbeddings(keys, embeddings);
 }
 
-void SSDEngine::DeleteEmbeddings(const string &tableName, vector<emb_key_t> &keys)
+void SSDEngine::DeleteEmbeddings(const string &tableName, vector<emb_cache_key_t> &keys)
 {
     if (!isRunning) {
         throw runtime_error("SSDEngine not running");
@@ -102,9 +103,16 @@ void SSDEngine::Save(int step)
     if (!isRunning) {
         throw runtime_error("SSDEngine not running");
     }
+
+    if (step == loadStep) {
+        LOG_INFO("save step equal to load step, skip saving, step:{}", step);
+        return;
+    }
+
     for (auto item: as_const(tableMap)) {
         item.second->Save(step);
     }
+    saveStep = step;
 }
 
@@ -112,12 +120,19 @@ void SSDEngine::Load(const string &tableName, vector<string> savePaths, uint64_t
     if (!isRunning) {
         throw runtime_error("SSDEngine not running");
     }
+
+    if (step == saveStep) {
+        LOG_INFO("load step equal to save step, skip loading, step:{}", step);
+        return;
+    }
+
     auto it = as_const(tableMap).find(tableName);
     if (it != tableMap.end()) {
         throw invalid_argument("table already exist");
     }
     tableMap[tableName] = make_shared<Table>(tableName, savePaths, maxTableSize, compactThreshold, step);
+    loadStep = step;
 }
 
 void SSDEngine::Start()
@@ -154,7 +169,7 @@ void SSDEngine::CompactMonitor()
     LOG_DEBUG("SSDEngine end CompactMonitor");
 }
 
-vector<vector<float>> SSDEngine::FetchEmbeddings(const string &tableName, vector<emb_key_t> &keys)
+vector<vector<float>> SSDEngine::FetchEmbeddings(const string &tableName, vector<emb_cache_key_t> &keys)
 {
     if (!isRunning) {
         throw runtime_error("SSDEngine not running");
@@ -198,7 +213,7 @@ void SSDEngine::SetCompactThreshold(double threshold)
     throw invalid_argument("compact threshold should in range [0, 1]");
 }
 
-int64_t SSDEngine::GetTableEmbeddingSize(const string &tableName)
+int64_t SSDEngine::GetTableUsage(const string &tableName)
 {
     if (!isRunning) {
         throw runtime_error("SSDEngine not running");
@@ -209,3 +224,30 @@
     }
     return static_cast<int64_t>(it->second->GetTableUsage());
 }
+
+void SSDEngine::InsertEmbeddingsByAddr(const string& tableName, vector<emb_cache_key_t>& keys,
+    vector<float*>& embeddingsAddr, uint64_t extEmbeddingSize)
+{
+    if (!isRunning) {
+        throw runtime_error("SSDEngine not running");
+    }
+    auto it = as_const(tableMap).find(tableName);
+    if (it == tableMap.end()) {
+        throw invalid_argument("table not found");
+    }
+
+    if (keys.size() != embeddingsAddr.size()) {
+        throw invalid_argument("keys' length not equal to embeddings' length");
+    }
+
+    it->second->InsertEmbeddingsByAddr(keys, embeddingsAddr, extEmbeddingSize);
+}
+
+vector<pair<string, vector<emb_cache_key_t>>> SSDEngine::ExportTableKey()
+{
+    vector<pair<string, vector<emb_cache_key_t>>> tableKeysVec;
+    for (const auto& p : tableMap) {
+        tableKeysVec.emplace_back(p.first, p.second->ExportKeys());
+    }
+    return tableKeysVec;
+}
diff --git a/src/core/ssd_engine/ssd_engine.h b/src/core/ssd_engine/ssd_engine.h
index 10f89d5754e440d548e585701389b5f235a4f2e..942318c418567c6ebd39e66e0ff0a4f4914891a8 100644
--- a/src/core/ssd_engine/ssd_engine.h
+++ b/src/core/ssd_engine/ssd_engine.h
@@ -22,26 +22,27 @@ See the License for the specific language governing permissions and
 #include <string>
 #include <vector>
 
-#include "utils/common.h"
+#include "l3_storage/l3_storage.h"
 
 namespace MxRec {
-    class SSDEngine {
+    class SSDEngine : public L3Storage {
     public:
         bool IsTableExist(const string &tableName);
 
-        bool IsKeyExist(const string &tableName, emb_key_t key);
+        bool IsKeyExist(const string &tableName, emb_cache_key_t key);
 
         void CreateTable(const string &tableName, vector<string> savePaths, uint64_t maxTableSize);
 
         int64_t GetTableAvailableSpace(const string &tableName);
 
-        void InsertEmbeddings(const string &tableName, vector<emb_key_t> &keys, vector<vector<float>> &embeddings);
+        void InsertEmbeddings(const string &tableName, vector<emb_cache_key_t> &keys,
+            vector<vector<float>> &embeddings);
 
-        void DeleteEmbeddings(const string &tableName, vector<emb_key_t> &keys);
+        void DeleteEmbeddings(const string &tableName, vector<emb_cache_key_t> &keys);
 
-        vector<vector<float>> FetchEmbeddings(const string &tableName, vector<emb_key_t> &keys);
+        vector<vector<float>> FetchEmbeddings(const string &tableName, vector<emb_cache_key_t> &keys);
 
         void Save(int step);
@@ -55,7 +56,12 @@
 
         void SetCompactThreshold(double threshold);
 
-        int64_t GetTableEmbeddingSize(const string& tableName);
+        int64_t GetTableUsage(const string& tableName);
+
+        void InsertEmbeddingsByAddr(const string &tableName, vector<emb_cache_key_t> &keys,
+            vector<float*> &embeddingsAddr, uint64_t extEmbeddingSize);
+
+        vector<pair<string, vector<emb_cache_key_t>>> ExportTableKey();
 
     private:
         bool isRunning = false;
@@ -68,6 +74,9 @@
         shared_ptr<thread> compactThread = nullptr;
 
         void CompactMonitor();
+
+        int loadStep = -1;
+        int saveStep = -1;
     };
 }
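The new `loadStep`/`saveStep` members guard against redundant disk traffic: a save at the step that was just loaded (or a load at the step that was just saved) would rewrite or re-read identical table files. A minimal standalone sketch of that guard, with class and method names invented for illustration:

```cpp
#include <iostream>

// Minimal sketch of the save/load step guard added to SSDEngine (illustrative).
class StepGuard {
public:
    bool ShouldSave(int step)
    {
        if (step == loadStep) {
            return false;  // data on disk for this step is already current
        }
        saveStep = step;
        return true;
    }

    bool ShouldLoad(int step)
    {
        if (step == saveStep) {
            return false;  // in-memory state already matches this step
        }
        loadStep = step;
        return true;
    }

private:
    int loadStep = -1;
    int saveStep = -1;
};

int main()
{
    StepGuard g;
    // Load step 100, then try to save step 100 (skipped), then save step 200.
    std::cout << g.ShouldLoad(100) << g.ShouldSave(100) << g.ShouldSave(200) << '\n';  // prints 101
}
```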
diff --git a/src/core/ssd_engine/table.cpp b/src/core/ssd_engine/table.cpp
index c7ed5363162688b42df75e4969f6c2c1d5041559..9e48b0ef76715d1f4c05b5d77ec5ce1fd3ecded9 100644
--- a/src/core/ssd_engine/table.cpp
+++ b/src/core/ssd_engine/table.cpp
@@ -72,27 +72,27 @@ Table::Table(const string &name, vector<string> &saveDirs, uint64_t maxTableSize
     LOG_INFO("load table:{} done. try store at path:{}", name, curTablePath);
 }
 
-bool Table::IsKeyExist(emb_key_t key)
+bool Table::IsKeyExist(emb_cache_key_t key)
 {
     lock_guard<mutex> guard(rwLock);
     auto it = keyToFile.find(key);
     return !(it == keyToFile.end());
 }
 
-void Table::InsertEmbeddings(vector<emb_key_t> &keys, vector<vector<float>> &embeddings)
+void Table::InsertEmbeddings(vector<emb_cache_key_t> &keys, vector<vector<float>> &embeddings)
 {
     lock_guard<mutex> guard(rwLock);
     InsertEmbeddingsInner(keys, embeddings);
 }
 
-vector<vector<float>> Table::FetchEmbeddings(vector<emb_key_t> &keys)
+vector<vector<float>> Table::FetchEmbeddings(vector<emb_cache_key_t> &keys)
 {
     lock_guard<mutex> guard(rwLock);
     return FetchEmbeddingsInner(keys);
 }
 
-void Table::DeleteEmbeddings(vector<emb_key_t> &keys)
+void Table::DeleteEmbeddings(vector<emb_cache_key_t> &keys)
 {
     lock_guard<mutex> guard(rwLock);
     DeleteEmbeddingsInner(keys);
@@ -137,7 +137,7 @@ void Table::Save(int step)
         SetTablePathToDiskWithSpace();
     } catch (runtime_error &e) {
         metaFile.close();
-        throw runtime_error(StringFormat("set table path to disk with space error:{}", e.what()));
+        throw runtime_error(StringFormat("set table path to disk with space error:%s", e.what()));
     }
     try {
         CreateTableDir(curTablePath);
@@ -205,7 +205,7 @@ void Table::LoadDataFileSet(const shared_ptr<ifstream> &metaFile, int step)
         throw invalid_argument("table size too small, key quantity exceed while loading data");
     }
 
-    for (emb_key_t k: keys) {
+    for (emb_cache_key_t k: keys) {
         if (keyToFile.find(k) != keyToFile.end()) {
             throw invalid_argument(
                 "find duplicate key in files, compaction already done before saving, file may broken or modified");
@@ -258,7 +258,7 @@ void Table::Load(const string &metaFilePath, int step)
         LoadDataFileSet(metaFile, step);
     } catch (exception &e) {
         metaFile->close();
-        throw runtime_error(StringFormat("load data file set error:{}", e.what()));
+        throw runtime_error(StringFormat("load data file set error: %s", e.what()));
     }
     metaFile->close();
     if (metaFile->fail()) {
@@ -267,7 +267,7 @@
     LOG_INFO("table:{}, end load data file", name);
 }
 
-void Table::InsertEmbeddingsInner(vector<emb_key_t> &keys, vector<vector<float>> &embeddings)
+void Table::InsertEmbeddingsInner(vector<emb_cache_key_t> &keys, vector<vector<float>> &embeddings)
 {
     if (totalKeyCnt > maxTableSize) {
         throw invalid_argument("table size too small, key quantity exceed while loading data");
@@ -281,7 +281,7 @@
         curMaxFileID++;
     }
 
-    for (emb_key_t k: keys) {
+    for (emb_cache_key_t k: keys) {
         auto it = keyToFile.find(k);
         if (it != keyToFile.end()) {
             it->second->DeleteEmbedding(k);
@@ -294,25 +294,25 @@
     totalKeyCnt += keys.size();
 }
 
-vector<vector<float>> Table::FetchEmbeddingsInner(vector<emb_key_t> &keys)
+vector<vector<float>> Table::FetchEmbeddingsInner(vector<emb_cache_key_t> &keys)
 {
     // build mini batch for each file, first element for keys, second for index
     size_t dLen = keys.size();
-    unordered_map<shared_ptr<File>, shared_ptr<pair<vector<emb_key_t>, vector<size_t>>>> miniBatch;
+    unordered_map<shared_ptr<File>, shared_ptr<pair<vector<emb_cache_key_t>, vector<size_t>>>> miniBatch;
     for (size_t i = 0; i < dLen; ++i) {
         auto it = as_const(keyToFile).find(keys[i]);
         if (it == keyToFile.end()) {
             throw invalid_argument(StringFormat("failed to find the key, {key=%d} not exist!", keys[i]));
         }
         if (miniBatch[it->second] == nullptr) {
-            miniBatch[it->second] = make_shared<pair<vector<emb_key_t>, vector<size_t>>>();
+            miniBatch[it->second] = make_shared<pair<vector<emb_cache_key_t>, vector<size_t>>>();
         }
         miniBatch[it->second]->first.emplace_back(keys[i]);
         miniBatch[it->second]->second.emplace_back(i);
     }
 
     // must convert map to list to perform parallel query, omp not support to iterate map
-    vector<tuple<shared_ptr<File>, vector<emb_key_t>, vector<size_t>>> queryList;
+    vector<tuple<shared_ptr<File>, vector<emb_cache_key_t>, vector<size_t>>> queryList;
     queryList.reserve(miniBatch.size());
     for (auto [f, info]: miniBatch) {
         queryList.emplace_back(f, info->first, info->second);
@@ -368,7 +368,7 @@ void Table::Compact(bool fullCompact)
     for (const auto &f: compactFileList) {
         staleDataFileSet.erase(f);
         fileSet.erase(f);
-        vector<emb_key_t> validKeys = f->GetKeys();
+        vector<emb_cache_key_t> validKeys = f->GetKeys();
         vector<vector<float>> validEmbs = f->FetchEmbeddings(validKeys);
         InsertEmbeddingsInner(validKeys, validEmbs);
     }
@@ -381,9 +381,9 @@ uint64_t Table::GetTableAvailableSpace()
     return maxTableSize - totalKeyCnt;
 }
 
-void Table::DeleteEmbeddingsInner(vector<emb_key_t> &keys)
+void Table::DeleteEmbeddingsInner(vector<emb_cache_key_t> &keys)
 {
-    for (emb_key_t k: keys) {
+    for (emb_cache_key_t k: keys) {
         auto it = keyToFile.find(k);
         if (it != keyToFile.end()) {
             it->second->DeleteEmbedding(k);
@@ -441,3 +441,46 @@ void Table::CreateTableDir(const string &path)
     LOG_DEBUG("create table dir:{}", path);
 }
 
+void Table::InsertEmbeddingsByAddr(vector<emb_cache_key_t>& keys, vector<float*>& embeddingsAddr,
+    uint32_t extEmbeddingSize)
+{
+    lock_guard<mutex> guard(rwLock);
+    InsertEmbeddingsByAddrInner(keys, embeddingsAddr, extEmbeddingSize);
+}
+
+void Table::InsertEmbeddingsByAddrInner(vector<emb_cache_key_t>& keys, vector<float*>& embeddingsAddr,
+    uint64_t extEmbeddingSize)
+{
+    if (totalKeyCnt > maxTableSize) {
+        throw invalid_argument("table size too small, key quantity exceed while loading data");
+    }
+
+    if (curFile == nullptr || (curFile != nullptr && curFile->GetDataCnt() >= maxDataNumInFile)) {
+        SetTablePathToDiskWithSpace();
+        CreateTableDir(curTablePath);
+        curFile = make_shared<File>(curMaxFileID, curTablePath);
+        fileSet.insert(curFile);
+        curMaxFileID++;
+    }
+
+    for (emb_cache_key_t k : keys) {
+        auto it = keyToFile.find(k);
+        if (it != keyToFile.end()) {
+            it->second->DeleteEmbedding(k);
+            staleDataFileSet.insert(it->second);
+            totalKeyCnt -= 1;
+        }
+        keyToFile[k] = curFile;
+    }
+    curFile->InsertEmbeddingsByAddr(keys, embeddingsAddr, extEmbeddingSize);
+    totalKeyCnt += keys.size();
+}
+
+vector<emb_cache_key_t> Table::ExportKeys()
+{
+    vector<emb_cache_key_t> vec;
+    for (const auto& p : keyToFile) {
+        vec.push_back(p.first);
+    }
+    return vec;
+}
\ No newline at end of file
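Deletes and overwrites in `Table` only mark records stale (`staleDataFileSet`, `staleDataCnt`); reclaiming the space is deferred to `Compact`, which rewrites a file's surviving keys and drops the file. A sketch of the candidate-selection idea, assuming compaction triggers on a per-file stale ratio against `compactThreshold`; `DataFile` is a stand-in, not the repo's `File` class:

```cpp
#include <cstdint>
#include <memory>
#include <set>
#include <vector>

// Illustrative selection of files worth compacting: a file is rewritten only
// when its share of stale (overwritten/deleted) records crosses the threshold.
struct DataFile {
    uint64_t dataCnt = 0;       // records ever written to this file
    uint64_t staleDataCnt = 0;  // records since overwritten or deleted
};

std::vector<std::shared_ptr<DataFile>> SelectCompactCandidates(
    const std::set<std::shared_ptr<DataFile>>& staleFiles, double compactThreshold)
{
    std::vector<std::shared_ptr<DataFile>> candidates;
    for (const auto& f : staleFiles) {
        if (f->dataCnt == 0) {
            continue;
        }
        double staleRatio = static_cast<double>(f->staleDataCnt) / static_cast<double>(f->dataCnt);
        if (staleRatio >= compactThreshold) {
            candidates.push_back(f);  // valid keys get re-inserted, then the file is dropped
        }
    }
    return candidates;
}
```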
diff --git a/src/core/ssd_engine/table.h b/src/core/ssd_engine/table.h
index 87fa6f35c2e845dfb90b49891b91c1c43c6e7d65..c34837dc30511d5b476ec950955324800870b13d 100644
--- a/src/core/ssd_engine/table.h
+++ b/src/core/ssd_engine/table.h
@@ -32,18 +32,18 @@ namespace MxRec {
     class Table {
     public:
-        Table(const string &name, vector<string> &savePaths, uint64_t maxTableSize, double compactThreshold);
+        Table(const string& name, vector<string>& savePaths, uint64_t maxTableSize, double compactThreshold);
 
         // initialize with loading specific step data
-        Table(const string &name, vector<string> &saveDirs, uint64_t maxTableSize, double compactThreshold, int step);
+        Table(const string& name, vector<string>& saveDirs, uint64_t maxTableSize, double compactThreshold, int step);
 
-        bool IsKeyExist(emb_key_t key);
+        bool IsKeyExist(emb_cache_key_t key);
 
-        void InsertEmbeddings(vector<emb_key_t> &keys, vector<vector<float>> &embeddings);
+        void InsertEmbeddings(vector<emb_cache_key_t>& keys, vector<vector<float>>& embeddings);
 
-        vector<vector<float>> FetchEmbeddings(vector<emb_key_t> &keys);
+        vector<vector<float>> FetchEmbeddings(vector<emb_cache_key_t>& keys);
 
-        void DeleteEmbeddings(vector<emb_key_t> &keys);
+        void DeleteEmbeddings(vector<emb_cache_key_t>& keys);
 
         void Save(int step);
@@ -53,26 +53,34 @@
 
         uint64_t GetTableUsage();
 
+        void InsertEmbeddingsByAddr(vector<emb_cache_key_t>& keys, vector<float*>& embeddingsAddr,
+            uint32_t extEmbeddingSize);
+
+        vector<emb_cache_key_t> ExportKeys();
+
     private:
         static void CreateTableDir(const string& path);
 
         void Load(const string& metaFilePath, int step);
 
-        void InsertEmbeddingsInner(vector<emb_key_t> &keys, vector<vector<float>> &embeddings);
+        void InsertEmbeddingsInner(vector<emb_cache_key_t>& keys, vector<vector<float>>& embeddings);
 
-        void DeleteEmbeddingsInner(vector<emb_key_t> &keys);
+        void DeleteEmbeddingsInner(vector<emb_cache_key_t>& keys);
 
-        vector<vector<float>> FetchEmbeddingsInner(vector<emb_key_t> &keys);
+        vector<vector<float>> FetchEmbeddingsInner(vector<emb_cache_key_t>& keys);
 
         void LoadDataFileSet(const shared_ptr<ifstream>& metaFile, int step);
 
         void SetTablePathToDiskWithSpace();
 
+        void InsertEmbeddingsByAddrInner(vector<emb_cache_key_t>& keys, vector<float*>& embeddingsAddr,
+            uint64_t extEmbeddingSize);
+
         string name;               // init by constructor
         vector<string> savePaths;  // init by constructor, support Save and Load from multiple path
         uint64_t maxTableSize;     // init by constructor, maximum key-value volume
         uint64_t totalKeyCnt = 0;
 
-        unordered_map<emb_key_t, shared_ptr<File>> keyToFile{};  // max mem cost 1.5G*2 for 100m keys
+        unordered_map<emb_cache_key_t, shared_ptr<File>> keyToFile{};  // max mem cost 1.5G*2 for 100m keys
         set<shared_ptr<File>> staleDataFileSet{};
         string curTablePath = "";
         uint32_t curSavePathIdx = 0;
diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp
index 38e64444cb9e599936b9058c595d8431ff911d1b..15aa69bbb54642eab55450b6b68da1425de8e25e 100644
--- a/src/core/utils/common.cpp
+++ b/src/core/utils/common.cpp
@@ -20,6 +20,7 @@ See the License for the specific language governing permissions and
 #include <sstream>
 #include <fstream>
 #include <sys/stat.h>
+#include <regex>
 #include <unistd.h>
@@ -37,6 +38,8 @@ namespace MxRec {
     int GlogConfig::gGlogLevel;
     string GlogConfig::gRankId;
 
+    ock::ctr::FactoryPtr factory {};
+
     RankInfo::RankInfo(int rankId, int deviceId, int localRankSize, int option, const vector<int>& ctrlSteps)
         : rankId(rankId), deviceId(deviceId), localRankSize(localRankSize), option(option), ctrlSteps(ctrlSteps)
     {
@@ -45,8 +48,8 @@
             localRankId = rankId % localRankSize;
         }
         useStatic = static_cast<unsigned int>(option) bitand HybridOption::USE_STATIC;
-        useHot = static_cast<unsigned int>(option) bitand HybridOption::USE_HOT;
         useDynamicExpansion = static_cast<unsigned int>(option) bitand HybridOption::USE_DYNAMIC_EXPANSION;
+        useSumSameIdGradients = static_cast<unsigned int>(option) bitand HybridOption::USE_SUM_SAME_ID_GRADIENTS;
     }
 
     RankInfo::RankInfo(int localRankSize, int option, const vector<int>& maxStep)
@@ -58,7 +61,6 @@
             localRankId = rankId % localRankSize;
         }
         useStatic = static_cast<unsigned int>(option) & HybridOption::USE_STATIC;
-        useHot = static_cast<unsigned int>(option) & HybridOption::USE_HOT;
     }
 
     RandomInfo::RandomInfo(int start, int len, float constantVal, float randomMin, float randomMax)
@@ -148,10 +150,40 @@
         return true;
     }
 
+    std::string FloatPtrToLimitStr(float* ptr, const size_t& prtSize)
+    {
+        constexpr size_t maxDispLen = 10;  // max display number
+        int maxLen = static_cast<int>(std::min(prtSize, maxDispLen));
+        std::string s;
+        for (int i = 0; i < maxLen; i++) {
+            s += std::to_string(*(ptr + i)) + " ";
+        }
+        return s;
+    }
+
     ostream& operator<<(ostream& ss, MxRec::CkptDataType type)
     {
         ss << static_cast<int>(type);
         return ss;
     }
+
+    int GetStepFromPath(const string& loadPath)
+    {
+        regex pattern(SAVE_SPARSE_PATH_PREFIX + "-.*-(\\d+)");
+        smatch match;
+        if (!regex_search(loadPath, match, pattern)) {
+            return 0;
+        }
+        int res = 0;
+        unsigned int minSize = 2;
+        if (match.size() < minSize) {
+            return res;
+        }
+        try {
+            res = stoi(match[1]);
+        } catch (const std::invalid_argument& e) {
+            LOG_ERROR("argument is invalid: {}", e.what());
+        }
+        return res;
+    }
 }  // end namespace MxRec
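`GetStepFromPath` assumes sparse checkpoint directories follow a `sparse-<name>-<step>` naming convention and falls back to 0 when no step is found. A standalone demo of that extraction; the path below is made up for illustration:

```cpp
#include <iostream>
#include <regex>
#include <string>

// Demo of the step-extraction convention used by GetStepFromPath above.
int main()
{
    const std::string prefix = "sparse";
    const std::regex pattern(prefix + "-.*-(\\d+)");
    const std::string loadPath = "/ckpt/sparse-user_emb-4200";  // illustrative path

    std::smatch match;
    if (std::regex_search(loadPath, match, pattern) && match.size() >= 2) {
        std::cout << "restored step: " << std::stoi(match[1]) << '\n';  // prints 4200
    } else {
        std::cout << "no step found, default to 0\n";
    }
}
```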
diff --git a/src/core/utils/common.h b/src/core/utils/common.h
index f6c3de3fb5f29734772fbe1bdb044c8fbfda7dbe..8c7528f4fbd7119bf086475dc1ee65bc3cc37cc6 100644
--- a/src/core/utils/common.h
+++ b/src/core/utils/common.h
@@ -17,601 +17,639 @@ See the License for the specific language governing permissions and
 #define COMMON_H
 
 #include <iostream>
-#include <map>
-#include <set>
-#include <string>
+
+#include <algorithm>
 #include <unordered_map>
+#include <map>
+#include <memory>
 #include <vector>
+#include <string>
 #include <utility>
-#include <tuple>
-#include <sstream>
-#include "tensorflow/core/framework/tensor.h"
-#include "absl/container/flat_hash_map.h"
-#include "securec.h"
-#include "utils/logger.h"
-#include "utils/config.h"
+#include <tuple>
 
-#include "initializer/initializer.h"
+#include "absl/container/flat_hash_map.h"
 #include "initializer/constant_initializer/constant_initializer.h"
-#include "initializer/truncated_normal_initializer/truncated_normal_initializer.h"
+#include "initializer/initializer.h"
 #include "initializer/random_normal_initializer/random_normal_initializer.h"
+#include "initializer/truncated_normal_initializer/truncated_normal_initializer.h"
+#include "ock_ctr_common/include/embedding_cache.h"
+#include "ock_ctr_common/include/factory.h"
+#include "securec.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "utils/config.h"
+#include "utils/logger.h"
 
 #if defined(BUILD_WITH_EASY_PROFILER)
-    #include <easy/profiler.h>
-    #include <easy/arbitrary_value.h>
+#include <easy/profiler.h>
+#include <easy/arbitrary_value.h>
 #else
-    #define EASY_FUNCTION(...)
-    #define EASY_VALUE(...)
-    #define EASY_BLOCK(...)
-    #define EASY_END_BLOCK
-    #define EASY_PROFILER_ENABLE
-    #define EASY_PROFILER_DISABLE
+#define EASY_FUNCTION(...)
+#define EASY_VALUE(...)
+#define EASY_BLOCK(...)
+#define EASY_END_BLOCK
+#define EASY_PROFILER_ENABLE
+#define EASY_PROFILER_DISABLE
 #endif
 
 namespace MxRec {
 #define INFO_PTR shared_ptr
 #define MGMT_CPY_THREADS 4
 #define PROFILING
-    using namespace tensorflow;
-    constexpr int TRAIN_CHANNEL_ID = 0;
-    constexpr int EVAL_CHANNEL_ID = 1;
-
-    constexpr int MAX_CHANNEL_NUM = 2;
-    constexpr int MAX_KEY_PROCESS_THREAD = 10;
-    constexpr int MAX_QUEUE_NUM = MAX_CHANNEL_NUM * MAX_KEY_PROCESS_THREAD;
-    constexpr int DEFAULT_KEY_PROCESS_THREAD = 6;
-    constexpr int KEY_PROCESS_THREAD = 6;
-    constexpr char SUM_SAME_ID[] = "sum_same_id_gradients_and_apply";
-    constexpr size_t MAX_VOCABULARY_SIZE = 1e10;
-    constexpr int SSD_SIZE_INDEX = 2;
-    constexpr int MAX_FILE_NUM = 1000;
-    // for GLOG
-    struct GlogConfig {
-        static bool gStatOn;
-        static int gGlogLevel;
-        static string gRankId;
-    };
-
-    constexpr int GLOG_MAX_BUF_SIZE = 1024;
-    constexpr int GLOG_TIME_WIDTH_2 = 2;
-    constexpr int GLOG_TIME_WIDTH_6 = 6;
-    constexpr char GLOG_STAT_FLAG[] = "statOn";
-
-    // unique related config
-    constexpr int UNIQUE_BUCKET = 6;
-    constexpr int MIN_UNIQUE_THREAD_NUM = 1;
-
-    // validate file
-    constexpr long long FILE_MAX_SIZE = 1LL << 40;
-    constexpr int FILE_MIN_SIZE = 0;
-    constexpr size_t BUFFER_SIZE{1024 * 1024 * 64};
-    constexpr size_t MAP_BYTE_SIZE{static_cast<size_t>(10) * 1024 * 1024 * 1024};
+using namespace tensorflow;
+extern ock::ctr::FactoryPtr factory;
+constexpr int TRAIN_CHANNEL_ID = 0;
+constexpr int EVAL_CHANNEL_ID = 1;
+
+constexpr int MAX_CHANNEL_NUM = 2;
+constexpr int MAX_KEY_PROCESS_THREAD = 10;
+constexpr int MAX_QUEUE_NUM = MAX_CHANNEL_NUM * MAX_KEY_PROCESS_THREAD;
+constexpr int DEFAULT_KEY_PROCESS_THREAD = 6;
+constexpr int KEY_PROCESS_THREAD = 6;
+constexpr char SUM_SAME_ID[] = "sum_same_id_gradients_and_apply";
+constexpr size_t MAX_VOCABULARY_SIZE = 1e10;
+constexpr int SSD_SIZE_INDEX = 2;
+constexpr int MAX_FILE_NUM = 1000;
+constexpr int EMBEDDING_THREAD_NUM = 2;
+constexpr int HOST_TO_PREFILL_RATIO = 10;
+// for GLOG
+struct GlogConfig {
+    static bool gStatOn;
+    static int gGlogLevel;
+    static string gRankId;
+};
+
+constexpr int GLOG_MAX_BUF_SIZE = 1024;
+constexpr int GLOG_TIME_WIDTH_2 = 2;
+constexpr int GLOG_TIME_WIDTH_6 = 6;
+constexpr char GLOG_STAT_FLAG[] = "statOn";
+
+// unique related config
+constexpr int UNIQUE_BUCKET = 6;
+constexpr int MIN_UNIQUE_THREAD_NUM = 1;
+
+// validate file
+constexpr long long FILE_MAX_SIZE = 1LL << 40;
+constexpr int FILE_MIN_SIZE = 0;
+constexpr size_t BUFFER_SIZE{1024 * 1024 * 64};
+constexpr size_t MAP_BYTE_SIZE{static_cast<size_t>(10) * 1024 * 1024 * 1024};
 #ifdef GTEST
-    constexpr int KEY_PROCESS_TIMEOUT = 3;
+constexpr int KEY_PROCESS_TIMEOUT = 3;
 #else
-    constexpr int KEY_PROCESS_TIMEOUT = 120;
+constexpr int KEY_PROCESS_TIMEOUT = 120;
 #endif
-    constexpr int GET_BATCH_TIMEOUT = 300;
-    constexpr int EOS_TIMEOUT = 30;
-
-    constexpr size_t DEFAULT_RANDOM_SEED = 10086;
-    // constexpr int INVALID_KEY_VALUE = -1;
-    constexpr int64_t INVALID_KEY_VALUE = -1;
-    constexpr int ALLTOALLVC_ALIGN = 128;
-    constexpr int PROFILING_START_BATCH_ID = 100;
-    constexpr int PROFILING_END_BATCH_ID = 200;
-    constexpr int MGMT_THREAD_BIND = 48;
-    constexpr int UNIQUE_MAX_BUCKET_WIDTH = 6;
-    constexpr int HOT_EMB_UPDATE_STEP_DEFAULT = 1000;
-    constexpr float HOT_EMB_CACHE_PCT = static_cast<float>(1. / 3); // hot emb cache percent
-
-    const string COMBINE_HISTORY_NAME = "combine_table_history";
-
-    using emb_key_t = int64_t;
-    using freq_num_t = int64_t;
-    using EmbNameT= std::string;
-    using KeysT = std::vector<emb_key_t>;
-    using LookupKeyT = std::tuple<int, int, KeysT>; // batch_id quarry_lable keys_vector
-    using TensorInfoT = std::tuple<int, std::vector<std::vector<Tensor>>::iterator>;
-
-    namespace HybridOption {
-        const unsigned int USE_STATIC = 0x001;
-        const unsigned int USE_HOT = 0x001 << 1;
-        const unsigned int USE_DYNAMIC_EXPANSION = 0x001 << 2;
-    };
-
-    string GetChipName(int devID);
-    int GetThreadNumEnv();
-
-    namespace UBSize {
-        const int ASCEND910_PREMIUM_A = 262144;
-        const int ASCEND910_PRO_B = 262144;
-        const int ASCEND910_B2 = 196608;
-        const int ASCEND910_B1 = 196608;
-        const int ASCEND910_B3 = 196608;
-        const int ASCEND910_B4 = 196608;
-        const int ASCEND910_C1 = 196608;
-        const int ASCEND910_C2 = 196608;
-        const int ASCEND910_C3 = 196608;
-        const int ASCEND920_A = 196608;
-        const int ASCEND910_PRO_A = 262144;
-        const int ASCEND910_B = 262144;
-        const int ASCEND910_A = 262144;
-        const int ASCEND910_B2C = 196608;
-    };
-
-    inline int GetUBSize(int devID)
-    {
-        const std::map<std::string, int> chipUbSizeList = {{"910A", UBSize::ASCEND910_A},
-                                                           {"910B", UBSize::ASCEND910_B},
-                                                           {"920A", UBSize::ASCEND920_A},
-                                                           {"910B1", UBSize::ASCEND910_B1},
-                                                           {"910B2", UBSize::ASCEND910_B2},
-                                                           {"910B3", UBSize::ASCEND910_B3},
-                                                           {"910B4", UBSize::ASCEND910_B4},
-                                                           {"910B2C", UBSize::ASCEND910_B2C}};
-        auto it = chipUbSizeList.find(GetChipName(devID));
-        if (it != chipUbSizeList.end()) {
-            return it->second;
-        }
-
-        throw std::runtime_error("unknown chip ub size" + GetChipName(devID));
-    }
+constexpr int GET_BATCH_TIMEOUT = 300;
+constexpr int EOS_TIMEOUT = 30;
+
+constexpr size_t DEFAULT_RANDOM_SEED = 10086;
+constexpr int64_t INVALID_KEY_VALUE = -1;
+constexpr int32_t INVALID_INDEX_VALUE = -1;
+constexpr int ALLTOALLVC_ALIGN = 128;
+constexpr int PROFILING_START_BATCH_ID = 100;
+constexpr int PROFILING_END_BATCH_ID = 200;
+constexpr int MGMT_THREAD_BIND = 48;
+constexpr int UNIQUE_MAX_BUCKET_WIDTH = 6;
+constexpr int HOT_EMB_UPDATE_STEP_DEFAULT = 1000;
+constexpr float HOT_EMB_CACHE_PCT = static_cast<float>(1. / 3);  // hot emb cache percent
+
+const string COMBINE_HISTORY_NAME = "combine_table_history";
+const string SAVE_SPARSE_PATH_PREFIX = "sparse";
+
+using emb_key_t = int64_t;
+using emb_cache_key_t = uint64_t;
+using freq_num_t = int64_t;
+using EmbNameT = std::string;
+using KeysT = std::vector<emb_key_t>;
+using LookupKeyT = std::tuple<int, int, KeysT>;  // batch_id quarry_lable keys_vector
+using UinqueKeyT = std::tuple<KeysT, std::vector<int>>;
+using RestoreVecSecT = std::tuple<int, std::vector<int>>;
+using TensorInfoT = std::tuple<int, std::vector<std::vector<Tensor>>::iterator>;
+
+namespace HybridOption {
+const unsigned int USE_STATIC = 0x001;
+const unsigned int USE_DYNAMIC_EXPANSION = 0x001 << 1;
+const unsigned int USE_SUM_SAME_ID_GRADIENTS = 0x001 << 2;
+};  // namespace HybridOption
+
+string GetChipName(int devID);
+int GetThreadNumEnv();
+
+namespace UBSize {
+const int ASCEND910_PREMIUM_A = 262144;
+const int ASCEND910_PRO_B = 262144;
+const int ASCEND910_B2 = 196608;
+const int ASCEND910_B1 = 196608;
+const int ASCEND910_B3 = 196608;
+const int ASCEND910_B4 = 196608;
+const int ASCEND910_C1 = 196608;
+const int ASCEND910_C2 = 196608;
+const int ASCEND910_C3 = 196608;
+const int ASCEND920_A = 196608;
+const int ASCEND910_PRO_A = 262144;
+const int ASCEND910_B = 262144;
+const int ASCEND910_A = 262144;
+const int ASCEND910_B2C = 196608;
+};  // namespace UBSize
+
+inline int GetUBSize(int devID)
+{
+    const std::map<std::string, int> chipUbSizeList = {
+        {"910A", UBSize::ASCEND910_A},   {"910B", UBSize::ASCEND910_B},     {"920A", UBSize::ASCEND920_A},
+        {"910B1", UBSize::ASCEND910_B1}, {"910B2", UBSize::ASCEND910_B2},   {"910B3", UBSize::ASCEND910_B3},
+        {"910B4", UBSize::ASCEND910_B4}, {"910B2C", UBSize::ASCEND910_B2C}, {"910C1", UBSize::ASCEND910_C1},
+        {"910C2", UBSize::ASCEND910_C1}, {"910C3", UBSize::ASCEND910_C3}};
+    auto it = chipUbSizeList.find(GetChipName(devID));
+    if (it != chipUbSizeList.end()) {
+        return it->second;
+    }
+
+    throw std::runtime_error("unknown chip ub size" + GetChipName(devID));
+}
 
-    template<typename T>
-    struct Batch {
-        size_t Size() const
-        {
-            return sample.size();
-        }
-
-        std::string UnParse() const
-        {
-            std::string s;
-            constexpr size_t maxDispLen = 20;
-            int maxLen = static_cast<int>(std::min(sample.size(), maxDispLen));
-            for (int i = 0; i < maxLen; i++) {
-                s += std::to_string(sample[i]) + " ";
-            }
-            return s;
-        }
-
-        std::vector<T> sample;
-        std::string name;
-        size_t batchSize;
-        int batchId;
-        int channel = 0;
-        time_t timestamp { -1 };
-    };
-
-    struct BatchTask {
-        vector<int64_t> splits;
-        vector<string> embNames;
-        size_t batchSize;
-        int batchQueueId;
-        int batchId;
-        int channelId;
-        time_t timestamp { -1 };
-        const void *tensor;
-    };
-
-    using EmbBatchT = Batch<emb_key_t>;
-    using BatchTaskT = BatchTask;
-
-    struct DDRParam {
-        vector<float> tmpDataOut;
-        vector<size_t> offsetsOut;
-        DDRParam(vector<float> tmpData, vector<size_t> offset)
-        {
-            tmpDataOut = tmpData;
-            offsetsOut = offset;
-        }
-    };
-
-    struct RankInfo {
-        RankInfo() = default;
-
-        RankInfo(int rankId, int deviceId, int localRankSize, int option, const std::vector<int>& ctrlSteps);
-        RankInfo(int localRankSize, int option, const std::vector<int>& maxStep);
-
-        int rankId {};
-        int deviceId {};
-        int rankSize {};
-        int localRankId {};
-        int localRankSize {};
-        bool useStatic { false };
-        bool useHot {};
-        uint32_t option {};
-        int nBatch {};
-        bool isDDR { false };
-        bool isSSDEnabled { false };
-        bool useDynamicExpansion {false};
-        std::vector<int> ctrlSteps; // 包含三个步数: train_steps, eval_steps, save_steps
-    };
-
-    enum TensorIndex : uint32_t {
-        TENSOR_INDEX_0,
-        TENSOR_INDEX_1,
-        TENSOR_INDEX_2,
-        TENSOR_INDEX_3,
-        TENSOR_INDEX_4,
-        TENSOR_INDEX_5,
-        TENSOR_INDEX_6,
-        TENSOR_INDEX_7,
-        TENSOR_INDEX_8
-    };
-
-    enum TupleIndex : uint32_t {
-        TUPLE_INDEX_0 = 0,
-        TUPLE_INDEX_1,
-        TUPLE_INDEX_2,
-        TUPLE_INDEX_3,
-        TUPLE_INDEX_4,
-        TUPLE_INDEX_5,
-        TUPLE_INDEX_6,
-        TUPLE_INDEX_7
-    };
-
-    struct RandomInfo {
-        RandomInfo() = default;
-
-        RandomInfo(int start, int len, float constantVal, float randomMin, float randomMax);
-
-        int start;
-        int len;
-        float constantVal;
-        float randomMin;
-        float randomMax;
-    };
-
-    struct EmbeddingSizeInfo {
-        EmbeddingSizeInfo() = default;
-        EmbeddingSizeInfo(size_t embSize, size_t extendSize)
-        {
-            embeddingSize = embSize;
-            extendEmbSize = extendSize;
-        }
-
-        size_t embeddingSize;
-        size_t extendEmbSize;
-    };
-
-    struct OptimizerInfo {
-        OptimizerInfo() = default;
-        OptimizerInfo(std::string name, vector<float> params)
-        {
-            optimName = name;
-            optimParams = std::move(params);
-        }
-
-        std::string optimName;
-        vector<float> optimParams;
-    };
-
-    struct ThresholdValue {
-        ThresholdValue() = default;
-        ThresholdValue(EmbNameT name, int countThre, int timeThre, int faaeCoef, bool isSum)
-        {
-            tableName = name;
-            countThreshold = countThre;
-            timeThreshold = timeThre;
-            faaeCoefficient = faaeCoef;
-            isEnableSum = isSum;
-        }
-
-        EmbNameT tableName { "" }; // embName
-        int countThreshold { -1 }; // 只配置count,即“只有准入、而没有淘汰”功能,对应SingleHostEmbTableStatus::SETS_ONLY_ADMIT状态
-        int timeThreshold { -1 };  // 只配置time,配置错误;即准入是淘汰的前提,对应SingleHostEmbTableStatus::SETS_BOTH状态
-        int faaeCoefficient { 1 }; // 配置后,该表在准入时,count计数会乘以该系数
-        bool isEnableSum {true};   // 配置false,该表在准入时,count计数不会累加
-    };
-
-    struct FeatureItemInfo {
-        FeatureItemInfo() = default;
-        FeatureItemInfo(uint32_t cnt, time_t lastT)
-            : count(cnt), lastTime(lastT)
-        {}
-
-        uint32_t count { 0 };
-        time_t lastTime { 0 };
-    };
-
-    using HistoryRecords = absl::flat_hash_map<EmbNameT, absl::flat_hash_map<emb_key_t, FeatureItemInfo>>;
-    struct AdmitAndEvictData {
-        HistoryRecords historyRecords; // embName ---> {id, FeatureItemInfo} 映射
-        absl::flat_hash_map<EmbNameT, time_t> timestamps; // 用于特征准入&淘汰的时间戳
-    };
-
-    void SetLog(int rank);
-
-    template<typename... Args>
-    string StringFormat(const string& format, Args ... args)
-    {
-        auto size = static_cast<size_t>(GLOG_MAX_BUF_SIZE);
-        auto buf = std::make_unique<char[]>(size);
-        memset_s(buf.get(), size, 0, size);
-        int nChar = snprintf_s(buf.get(), size, size - 1, format.c_str(), args ...);
-        if (nChar == -1) {
-            throw invalid_argument("StringFormat failed");
-        }
-        return string(buf.get(), buf.get() + nChar);
-    }
-
-    // use environment variable GLOG_v to decide if showing debug log.
-    // default 0, debug message will not display.
-    // 1 for debug, 2 for trace
-    constexpr int GLOG_DEBUG = 1;
-    constexpr int GLOG_TRACE = 2;
-
-    template<typename T>
-    std::string VectorToString(const std::vector<T>& vec)
-    {
-        std::stringstream ss;
-        ss << "[";
-        for (size_t i = 0; i < vec.size(); ++i) {
-            ss << vec[i];
-            if (i != vec.size() - 1) {
-                ss << ", ";
-            }
-        }
-        ss << "]";
-        return ss.str();
-    }
-
-    template<typename K, typename V>
-    std::string MapToString(const std::map<K, V>& map)
-    {
-        std::stringstream ss;
-        ss << "{";
-        for (auto it = map.begin(); it != map.end(); ++it) {
-            ss << it->first << ": " << it->second;
-            if (std::next(it) != map.end()) {
-                ss << ", ";
-            }
-        }
-        ss << "}";
-        return ss.str();
-    }
-
-    template<typename K, typename V>
-    std::string MapToString(const absl::flat_hash_map<K, V>& map)
-    {
-        std::stringstream ss;
-        ss << "{";
-        for (auto it = map.begin(); it != map.end(); ++it) {
-            ss << it->first << ": " << it->second;
-            if (std::next(it) != map.end()) {
-                ss << ", ";
-            }
-        }
-        ss << "}";
-        return ss.str();
-    }
-
-    void ValidateReadFile(const string& dataDir, size_t datasetSize);
-
-    template<typename T>
-    inline Tensor Vec2TensorI32(const std::vector<T>& data)
-    {
-        Tensor tmpTensor(tensorflow::DT_INT32, { static_cast<int64_t>(data.size()) });
-        auto tmpData = tmpTensor.flat<int32>();
-        for (int j = 0; j < static_cast<int>(data.size()); j++) {
-            tmpData(j) = static_cast<int32>(data[j]);
-        }
-        return tmpTensor;
-    }
-
-    template<typename T>
-    inline Tensor Vec2TensorI64(const std::vector<T>& data)
-    {
-        Tensor tmpTensor(tensorflow::DT_INT64, { static_cast<int64_t>(data.size()) });
-        auto tmpData = tmpTensor.flat<int64>();
-        for (int j = 0; j < static_cast<int>(data.size()); j++) {
-            tmpData(j) = static_cast<int64>(data[j]);
-        }
-        return tmpTensor;
-    }
-
-    struct EmbInfoParams {
-        EmbInfoParams() = default;
-
-        EmbInfoParams(const std::string& name,
-            int sendCount,
-            int embeddingSize,
-            int extEmbeddingSize,
-            bool isSave,
-            bool isGrad)
-            : name(name),
-              sendCount(sendCount),
-              embeddingSize(embeddingSize),
-              extEmbeddingSize(extEmbeddingSize),
-              isSave(isSave),
-              isGrad(isGrad)
-        {
-        }
-        std::string name;
-        int sendCount;
-        int embeddingSize;
-        int extEmbeddingSize;
-        bool isSave;
-        bool isGrad;
-    };
-
-    struct EmbInfo {
-        EmbInfo() = default;
-
-        EmbInfo(const EmbInfoParams& embInfoParams,
-            std::vector<size_t> vocabsize,
-            std::vector<RandomInfo> initializeInfos,
-            std::vector<string> ssdDataPath)
-            : name(embInfoParams.name),
-              sendCount(embInfoParams.sendCount),
-              embeddingSize(embInfoParams.embeddingSize),
-              extEmbeddingSize(embInfoParams.extEmbeddingSize),
-              isSave(embInfoParams.isSave),
-              isGrad(embInfoParams.isGrad),
-              devVocabSize(vocabsize[0]),
-              hostVocabSize(vocabsize[1]),
-              ssdVocabSize(vocabsize[SSD_SIZE_INDEX]),
-              initializeInfos(initializeInfos),
-              ssdDataPath(std::move(ssdDataPath))
-        {
-        }
-
-        std::string name;
-        int sendCount;
-        int embeddingSize;
-        int extEmbeddingSize;
-        bool isSave;
-        bool isGrad;
-        size_t devVocabSize;
-        size_t hostVocabSize;
-        size_t ssdVocabSize;
-        std::vector<RandomInfo> initializeInfos;
-        std::vector<string> ssdDataPath;
-    };
-
-    struct HostEmbTable {
-        EmbInfo hostEmbInfo;
-        std::vector<std::vector<float>> embData;
-    };
-
-    struct EmbHashMapInfo {
-        absl::flat_hash_map<emb_key_t, size_t> hostHashMap; // key在HBM中的偏移
-        std::vector<int> devOffset2Batch; // has -1
-        std::vector<emb_key_t> devOffset2Key;
-        size_t currentUpdatePos;
-        size_t currentUpdatePosStart;
-        size_t hostVocabSize;
-        size_t devVocabSize;
-        size_t freeSize;
-        std::vector<size_t> lookUpVec;
-        std::vector<size_t> missingKeysHostPos; // 用于记录当前batch在host上需要换出的偏移
-        std::vector<size_t> swapPos; // 记录从HBM换出到DDR的offset
-        /*
-         * 取值范围:[0,devVocabSize+hostVocabSize);
-         * [0,devVocabSize-1]时存储在HBM, [devVocabSize,devVocabSize+hostVocabSize)存储在DDR
-         */
-        size_t maxOffset { 0 };
-        /*
-         * 记录DDR内淘汰列表,其值为相对HBM+DDR大表的;hostHashMap可直接使用;操作ddr内emb时需减掉devVocabSize
-         * 例如:HBM表大小20(offset:0~19),DDR表大小为100(offset:0~99);
-         * 若DDR内0位置被淘汰,记录到evictPos的值为0+20=20
-         */
-        std::vector<size_t> evictPos;
-        std::vector<size_t> evictDevPos; // 记录HBM内淘汰列表
-        size_t maxOffsetOld { 0 };
-        std::vector<size_t> evictPosChange;
-        std::vector<size_t> evictDevPosChange;
-        std::vector<std::pair<size_t, emb_key_t>> devOffset2KeyOld;
-        std::vector<std::pair<size_t, size_t>> oldSwap; // (old on dev, old on host)
-        /*
-         * HBM与DDR换入换出时,已存在于DDR且要转移到HBM的key(不包含新key); 用于SSD模式
-         * (区别于oldSwap: pair.second为已存在于DDR key + 换入换出前映射到DDR的新key)
-         */
-        std::vector<emb_key_t> ddr2HbmKeys;
-        void SetStartCount();
-
-        bool HasFree(size_t i) const;
-    };
-
-    struct All2AllInfo {
-        KeysT keyRecv;
-        vector<int> scAll;
-        vector<int> countRecv;
-        All2AllInfo() = default;
-        All2AllInfo(KeysT keyRecv, vector<int> scAll, vector<int> countRecv)
-            : keyRecv(keyRecv), scAll(scAll), countRecv(countRecv) {}
-    };
-
-    struct UniqueInfo {
-        vector<int> restore;
-        vector<int> hotPos;
-        All2AllInfo all2AllInfo;
-        UniqueInfo() = default;
-        UniqueInfo(vector<int> restore, vector<int> hotPos, All2AllInfo all2AllInfo)
-            : restore(restore), hotPos(hotPos), all2AllInfo(all2AllInfo) {}
-    };
-
-    struct KeySendInfo {
-        KeysT keySend;
-        vector<int> keyCount;
-    };
-
-    using EmbMemT = absl::flat_hash_map<EmbNameT, HostEmbTable>;
-    using EmbHashMemT = absl::flat_hash_map<EmbNameT, EmbHashMapInfo>;
-    using OffsetMemT = std::map<EmbNameT, size_t>;
-    using KeyOffsetMemT = std::map<EmbNameT, std::map<emb_key_t, size_t>>;
-    using KeyCountMemT = std::map<EmbNameT, std::map<emb_key_t, size_t>>;
-    using Table2ThreshMemT = absl::flat_hash_map<EmbNameT, ThresholdValue>;
-    using trans_serialize_t = uint8_t;
-    using OffsetMapT = std::map<EmbNameT, std::vector<size_t>>;
-    using OffsetT = std::vector<size_t>;
-    using AllKeyOffsetMapT = std::map<EmbNameT, std::map<emb_key_t, size_t>>;
-    using KeyFreqMemT = unordered_map<EmbNameT, unordered_map<emb_key_t, freq_num_t>>;
-
-    enum class CkptFeatureType {
-        HOST_EMB = 0,
-        EMB_HASHMAP = 1,
-        MAX_OFFSET = 2,
-        KEY_OFFSET_MAP = 3,
-        FEAT_ADMIT_N_EVICT = 4,
-        DDR_KEY_FREQ_MAP = 5,
-        EXCLUDE_DDR_KEY_FREQ_MAP = 6,
-        KEY_COUNT_MAP = 7
-    };
-
-    struct CkptData {
-        EmbMemT* hostEmbs = nullptr;
-        EmbHashMemT embHashMaps;
-        OffsetMemT maxOffset;
-        KeyOffsetMemT keyOffsetMap;
-        OffsetMapT offsetMap;
-        OffsetMapT* offsetMapPtr = &offsetMap;
-        KeyCountMemT keyCountMap;
-        Table2ThreshMemT table2Thresh;
-        AdmitAndEvictData histRec;
-        KeyFreqMemT ddrKeyFreqMaps;
-        KeyFreqMemT excludeDDRKeyFreqMaps;
-    };
-
-    struct CkptTransData {
-        std::vector<int64_t> int64Arr;
-        std::vector<uint64_t> addressArr;
-        std::vector<float> floatArr;
-        std::vector<int32_t> int32Arr;
-        std::vector<trans_serialize_t> transDataset; // may all use this to transfer data
-        std::vector<trans_serialize_t> attribute; // may need to use other form for attributes
-        size_t datasetSize;
-        size_t attributeSize;
-    };
-
-    enum class CkptDataType {
-        EMB_INFO = 0,
-        EMB_DATA = 1,
-        EMB_HASHMAP = 2,
-        DEV_OFFSET = 3,
-        EMB_CURR_STAT = 4,
-        NDDR_OFFSET = 5,
-        NDDR_FEATMAP = 6,
-        TABLE_2_THRESH = 7,
-        HIST_REC = 8,
-        ATTRIBUTE = 9,
-        DDR_FREQ_MAP = 10,
-        EXCLUDE_FREQ_MAP = 11,
-        EVICT_POS = 12,
-        KEY_COUNT_MAP = 13
-    };
-
-    ostream& operator<<(ostream& ss, MxRec::CkptDataType type);
-    bool CheckFilePermission(const string& filePath);
-} // end namespace MxRec
+template<typename T>
+struct Batch {
+    size_t Size() const
+    {
+        return sample.size();
+    }
+
+    std::string UnParse() const
+    {
+        std::string s;
+        constexpr size_t maxDispLen = 20;
+        int maxLen = static_cast<int>(std::min(sample.size(), maxDispLen));
+        for (int i = 0; i < maxLen; i++) {
+            s += std::to_string(sample[i]) + " ";
+        }
+        return s;
+    }
+
+    std::vector<T> sample;
+    std::string name;
+    size_t batchSize;
+    int batchId;
+    int channel = 0;
+    time_t timestamp{-1};
+};
+
+struct BatchTask {
+    vector<int64_t> splits;
+    vector<string> embNames;
+    size_t batchSize;
+    int batchQueueId;
+    int batchId;
+    int channelId;
+    time_t timestamp{-1};
+    const void* tensor;
+};
+
+using EmbBatchT = Batch<emb_key_t>;
+using BatchTaskT = BatchTask;
+
+struct DDRParam {
+    vector<float> tmpDataOut;
+    vector<size_t> offsetsOut;
+    DDRParam(vector<float> tmpData, vector<size_t> offset)
+    {
+        tmpDataOut = tmpData;
+        offsetsOut = offset;
+    }
+};
+
+struct RankInfo {
+    RankInfo() = default;
+
+    RankInfo(int rankId, int deviceId, int localRankSize, int option, const std::vector<int>& ctrlSteps);
+    RankInfo(int localRankSize, int option, const std::vector<int>& maxStep);
+
+    int rankId{};
+    int deviceId{};
+    int rankSize{};
+    int localRankId{};
+    int localRankSize{};
+    bool useStatic{false};
+    uint32_t option{};
+    bool isDDR{false};
+    bool isSSDEnabled{false};
+    bool useDynamicExpansion{false};
+    bool useSumSameIdGradients{true};
+    std::vector<int> ctrlSteps;  // 包含4个步数: train_steps, eval_steps, save_steps, max_train_steps
+};
+
+struct EmbBaseInfo {
+    int batchId;
+    int channelId;
+    string name;
+};
+
+enum TensorIndex : uint32_t {
+    TENSOR_INDEX_0,
+    TENSOR_INDEX_1,
+    TENSOR_INDEX_2,
+    TENSOR_INDEX_3,
+    TENSOR_INDEX_4,
+    TENSOR_INDEX_5,
+    TENSOR_INDEX_6,
+    TENSOR_INDEX_7,
+    TENSOR_INDEX_8
+};
+
+enum TupleIndex : uint32_t {
+    TUPLE_INDEX_0 = 0,
+    TUPLE_INDEX_1,
+    TUPLE_INDEX_2,
+    TUPLE_INDEX_3,
+    TUPLE_INDEX_4,
+    TUPLE_INDEX_5,
+    TUPLE_INDEX_6,
+    TUPLE_INDEX_7
+};
+
+struct RandomInfo {
+    RandomInfo() = default;
+
+    RandomInfo(int start, int len, float constantVal, float randomMin, float randomMax);
+
+    int start;
+    int len;
+    float constantVal;
+    float randomMin;
+    float randomMax;
+};
+
+struct EmbeddingSizeInfo {
+    size_t embeddingSize = 0;
+    size_t extendEmbSize = 0;
+    EmbeddingSizeInfo() = default;
+    EmbeddingSizeInfo(size_t embSize, size_t extendSize) : embeddingSize(embSize), extendEmbSize(extendSize) {}
+};
+
+struct OptimizerInfo {
+    OptimizerInfo() = default;
+    OptimizerInfo(std::string name, vector<float> params)
+    {
+        optimName = name;
+        optimParams = std::move(params);
+    }
+
+    std::string optimName;
+    vector<float> optimParams;
+};
+
+struct ThresholdValue {
+    ThresholdValue() = default;
+    ThresholdValue(EmbNameT name, int countThre, int timeThre, int faaeCoef, bool isSum)
+    {
+        tableName = name;
+        countThreshold = countThre;
+        timeThreshold = timeThre;
+        faaeCoefficient = faaeCoef;
+        isEnableSum = isSum;
+    }
+
+    EmbNameT tableName{""};  // embName
+    int countThreshold{
+        -1};  // 只配置count,即“只有准入、而没有淘汰”功能,对应SingleHostEmbTableStatus::SETS_ONLY_ADMIT状态
+    int timeThreshold{-1};  // 只配置time,配置错误;即准入是淘汰的前提,对应SingleHostEmbTableStatus::SETS_BOTH状态
+    int faaeCoefficient{1};  // 配置后,该表在准入时,count计数会乘以该系数
+    bool isEnableSum{true};  // 配置false,该表在准入时,count计数不会累加
+};
+
+struct FeatureItemInfo {
+    FeatureItemInfo() = default;
+    FeatureItemInfo(uint32_t cnt, time_t lastT) : count(cnt), lastTime(lastT) {}
+
+    uint32_t count{0};
+    time_t lastTime{0};
+};
+
+using HistoryRecords = absl::flat_hash_map<EmbNameT, absl::flat_hash_map<emb_key_t, FeatureItemInfo>>;
+struct AdmitAndEvictData {
+    HistoryRecords historyRecords;  // embName ---> {id, FeatureItemInfo} 映射
+    absl::flat_hash_map<EmbNameT, time_t> timestamps;  // 用于特征准入&淘汰的时间戳
+};
+
+void SetLog(int rank);
+
+template<typename... Args>
+string StringFormat(const string& format, Args... args)
+{
+    auto size = static_cast<size_t>(GLOG_MAX_BUF_SIZE);
+    auto buf = std::make_unique<char[]>(size);
+    memset_s(buf.get(), size, 0, size);
+    int nChar = snprintf_s(buf.get(), size, size - 1, format.c_str(), args...);
+    if (nChar == -1) {
+        throw invalid_argument("StringFormat failed");
+    }
+    return string(buf.get(), buf.get() + nChar);
+}
+
+// use environment variable GLOG_v to decide if showing debug log.
+// default 0, debug message will not display.
+// 1 for debug, 2 for trace
+constexpr int GLOG_DEBUG = 1;
+constexpr int GLOG_TRACE = 2;
+
+template<typename T>
+std::string VectorToString(const std::vector<T>& vec)
+{
+    constexpr size_t maxDispLen = 20;  // max display number
+    int maxLen = static_cast<int>(std::min(vec.size(), maxDispLen));
+
+    std::stringstream ss;
+    ss << "[";
+    for (size_t i = 0; i < maxLen; ++i) {
+        ss << vec[i];
+        if (i != vec.size() - 1) {
+            ss << ", ";
+        }
+    }
+    ss << "]";
+    return ss.str();
+}
+
+std::string FloatPtrToLimitStr(float* ptr, const size_t& prtSize);
+
+template<typename K, typename V>
+std::string MapToString(const std::map<K, V>& map)
+{
+    std::stringstream ss;
+    ss << "{";
+    for (auto it = map.begin(); it != map.end(); ++it) {
+        ss << it->first << ": " << it->second;
+        if (std::next(it) != map.end()) {
+            ss << ", ";
+        }
+    }
+    ss << "}";
+    return ss.str();
+}
+
+template<typename K, typename V>
+std::string MapToString(const absl::flat_hash_map<K, V>& map)
+{
+    std::stringstream ss;
+    ss << "{";
+    for (auto it = map.begin(); it != map.end(); ++it) {
+        ss << it->first << ": " << it->second;
+        if (std::next(it) != map.end()) {
+            ss << ", ";
+        }
+    }
+    ss << "}";
+    return ss.str();
+}
+
+void ValidateReadFile(const string& dataDir, size_t datasetSize);
+
+template<typename T>
+inline Tensor Vec2TensorI32(const std::vector<T>& data)
+{
+    Tensor tmpTensor(tensorflow::DT_INT32, {static_cast<int64_t>(data.size())});
+    auto tmpData = tmpTensor.flat<int32>();
+    for (int j = 0; j < static_cast<int>(data.size()); j++) {
+        tmpData(j) = static_cast<int32>(data[j]);
+    }
+    return tmpTensor;
+}
+
+template<typename T>
+inline Tensor Vec2TensorI64(const std::vector<T>& data)
+{
+    Tensor tmpTensor(tensorflow::DT_INT64, {static_cast<int64_t>(data.size())});
+    auto tmpData = tmpTensor.flat<int64>();
+    for (int j = 0; j < static_cast<int>(data.size()); j++) {
+        tmpData(j) = static_cast<int64>(data[j]);
+    }
+    return tmpTensor;
+}
+
+struct EmbInfoParams {
+    std::string name;
+    int sendCount;
+    int embeddingSize;
+    int extEmbeddingSize;
+    bool isSave;
+    bool isGrad;
+    EmbInfoParams() = default;
+
+    EmbInfoParams(const std::string& name, int sendCount, int embeddingSize, int extEmbeddingSize, bool isSave,
+        bool isGrad)
+        : name(name),
+          sendCount(sendCount),
+          embeddingSize(embeddingSize),
+          extEmbeddingSize(extEmbeddingSize),
+          isSave(isSave),
+          isGrad(isGrad)
+    {
+    }
+};
+
+struct EmbInfo {
+    EmbInfo() = default;
+
+    EmbInfo(const EmbInfoParams& embInfoParams, std::vector<size_t> vocabsize,
+        std::vector<RandomInfo> initializeInfos, std::vector<string> ssdDataPath)
+        : name(embInfoParams.name),
+          sendCount(embInfoParams.sendCount),
+          embeddingSize(embInfoParams.embeddingSize),
+          extEmbeddingSize(embInfoParams.extEmbeddingSize),
+          isSave(embInfoParams.isSave),
+          isGrad(embInfoParams.isGrad),
+          devVocabSize(vocabsize[0]),
+          hostVocabSize(vocabsize[1]),
+          ssdVocabSize(vocabsize[SSD_SIZE_INDEX]),
+          initializeInfos(std::move(initializeInfos)),
+          ssdDataPath(std::move(ssdDataPath))
+    {
+    }
+
+    std::string name;
+    int sendCount;
+    int embeddingSize;
+    int extEmbeddingSize;
+    bool isSave;
+    bool isGrad;
+    size_t devVocabSize;
+    size_t hostVocabSize;
+    size_t ssdVocabSize;
+    std::vector<RandomInfo> initializeInfos;
+    std::vector<string> ssdDataPath;
+};
+
+struct HostEmbTable {
+    EmbInfo hostEmbInfo;
+    std::vector<std::vector<float>> embData;
+};
+
+struct All2AllInfo {
+    KeysT keyRecv;
+    vector<int> scAll;
+    vector<int> countRecv;
+    All2AllInfo() = default;
+    All2AllInfo(KeysT keyRecv, vector<int> scAll, vector<int> countRecv)
+        : keyRecv(keyRecv),
+          scAll(scAll),
+          countRecv(countRecv)
+    {
+    }
+};
+
+struct UniqueInfo {
+    vector<int> restore;
+    vector<int> hotPos;
+    All2AllInfo all2AllInfo;
+    UniqueInfo() = default;
+    UniqueInfo(vector<int> restore, vector<int> hotPos, All2AllInfo all2AllInfo)
+        : restore(restore),
+          hotPos(hotPos),
+          all2AllInfo(all2AllInfo)
+    {
+    }
+};
+
+struct KeySendInfo {
+    KeysT keySend;
+    vector<int> keyCount;
+};
+
+using EmbMemT = absl::flat_hash_map<EmbNameT, HostEmbTable>;
+using OffsetMemT = std::map<EmbNameT, size_t>;
+using KeyOffsetMemT = std::map<EmbNameT, std::map<emb_key_t, size_t>>;
+using KeyCountMemT = std::map<EmbNameT, std::map<emb_key_t, size_t>>;
+using Table2ThreshMemT = absl::flat_hash_map<EmbNameT, ThresholdValue>;
+using trans_serialize_t = uint8_t;
+using OffsetMapT = std::map<EmbNameT, std::vector<size_t>>;
+using OffsetT = std::vector<size_t>;
+using AllKeyOffsetMapT = std::map<EmbNameT, std::map<emb_key_t, size_t>>;
+using KeyFreqMemT = unordered_map<EmbNameT, unordered_map<emb_key_t, freq_num_t>>;
+using EmbLocalTableT = EmbCache::EmbCacheManager;
+
+enum class CkptFeatureType {
+    HOST_EMB = 0,
+    EMB_HASHMAP = 1,
+    MAX_OFFSET = 2,
+    KEY_OFFSET_MAP = 3,
+    FEAT_ADMIT_N_EVICT = 4,
+    DDR_KEY_FREQ_MAP = 5,
+    EXCLUDE_DDR_KEY_FREQ_MAP = 6,
+    KEY_COUNT_MAP = 7,
+    EMB_LOCAL_TABLE = 8
+};
+
+struct CkptData {
+    EmbMemT* hostEmbs = nullptr;
+    OffsetMemT maxOffset;
+    KeyOffsetMemT keyOffsetMap;
+    OffsetMapT offsetMap;
+    OffsetMapT* offsetMapPtr = &offsetMap;
+    KeyCountMemT keyCountMap;
+    Table2ThreshMemT table2Thresh;
+    AdmitAndEvictData histRec;
+    KeyFreqMemT ddrKeyFreqMaps;
+    KeyFreqMemT excludeDDRKeyFreqMaps;
+};
+
+struct CkptTransData {
+    std::vector<int64_t> int64Arr;
+    std::vector<uint64_t> addressArr;
+    std::vector<int32_t> int32Arr;
+    std::vector<trans_serialize_t> transDataset;  // may all use this to transfer data
+    std::vector<trans_serialize_t> attribute;     // may need to use other form for attributes
+    size_t datasetSize;
+    size_t attributeSize;
+};
+
+enum class CkptDataType {
+    EMB_INFO = 0,
+    EMB_DATA = 1,
+    EMB_HASHMAP = 2,
+    DEV_OFFSET = 3,
+    EMB_CURR_STAT = 4,
+    NDDR_OFFSET = 5,
+    NDDR_FEATMAP = 6,
+    TABLE_2_THRESH = 7,
+    HIST_REC = 8,
+    ATTRIBUTE = 9,
+    DDR_FREQ_MAP = 10,
+    EXCLUDE_FREQ_MAP = 11,
+    EVICT_POS = 12,
+    KEY_COUNT_MAP = 13
+};
+
+static std::string CkptDataTypeName(CkptDataType type)
+{
+    switch (type) {
+        case CkptDataType::EMB_INFO:
+            return "EMB_INFO";
+        case CkptDataType::EMB_DATA:
+            return "EMB_DATA";
+        case CkptDataType::EMB_HASHMAP:
+            return "EMB_HASHMAP";
+        case CkptDataType::DEV_OFFSET:
+            return "DEV_OFFSET";
+        case CkptDataType::EMB_CURR_STAT:
+            return "EMB_CURR_STAT";
+        case CkptDataType::NDDR_OFFSET:
+            return "NDDR_OFFSET";
+        case CkptDataType::NDDR_FEATMAP:
+            return "NDDR_FEATMAP";
+        case CkptDataType::TABLE_2_THRESH:
+            return "TABLE_2_THRESH";
+        case CkptDataType::HIST_REC:
+            return "HIST_REC";
+        case CkptDataType::ATTRIBUTE:
+            return "ATTRIBUTE";
+        case CkptDataType::DDR_FREQ_MAP:
+            return "DDR_FREQ_MAP";
+        case CkptDataType::EXCLUDE_FREQ_MAP:
+            return "EXCLUDE_FREQ_MAP";
+        case CkptDataType::EVICT_POS:
+            return "EVICT_POS";
+        case CkptDataType::KEY_COUNT_MAP:
+            return "KEY_COUNT_MAP";
+        default:
+            return "UNKNOWN";
+    }
+}
+
+enum CTRLogLevel {  // can't use enum class due to compatibility for AccCTR
+    DEBUG = 0,
+    INFO,
+    WARN,
+    ERROR,
+};
+
+static void CTRLog(int level, const char* msg)
+{
+    switch (level) {
+        case CTRLogLevel::DEBUG:
+            LOG_DEBUG(msg);
+            break;
+        case CTRLogLevel::INFO:
+            LOG_INFO(msg);
+            break;
+        case CTRLogLevel::WARN:
+            LOG_WARN(msg);
+            break;
+        case CTRLogLevel::ERROR:
+            LOG_ERROR(msg);
+            break;
+        default:
+            break;
+    }
+}
+
+ostream& operator<<(ostream& ss, MxRec::CkptDataType type);
+bool CheckFilePermission(const string& filePath);
+
+int GetStepFromPath(const string& loadPath);
+}  // end namespace MxRec
 
 #define KEY_PROCESS "\033[45m[KeyProcess]\033[0m "
 #define STAT_INFO "[StatInfo] "
 
 #ifdef GTEST
-    #define GTEST_PRIVATE public
+#define GTEST_PRIVATE public
 #else
-    #define GTEST_PRIVATE private
+#define GTEST_PRIVATE private
 #endif
 #endif
diff --git a/src/core/utils/config.cpp b/src/core/utils/config.cpp
index 9cfec7393695a426f6c6077389c303ba705f52e4..57478553c89fe51b4726e6cf120f6e3456b38b53 100644
--- a/src/core/utils/config.cpp
+++ b/src/core/utils/config.cpp
@@ -20,13 +20,7 @@ See the License for the specific language governing permissions and
 using namespace std;
 
 namespace MxRec {
-    namespace ApplyGradientsStrategyOptions {
-        const std::string DIRECT_APPLY = "direct_apply";
-        const std::string SUM_SAME_ID_GRADIENTS_AND_APPLY = "sum_same_id_gradients_and_apply";
-    };
-
     // 设置环境变量默认值
-    string GlobalEnv::applyGradientsStrategy = ApplyGradientsStrategyOptions::SUM_SAME_ID_GRADIENTS_AND_APPLY;
     int GlobalEnv::aclTimeout = -1;         // 默认阻塞方式,一直等待直到数据接收完成。
     int GlobalEnv::hdChannelSize = 40;      // 默认通道深度40
     int GlobalEnv::keyProcessThreadNum = 6; // 默认6个线程
@@ -42,12 +36,6 @@
     /// 配置环境变量,Python侧已经做了变量值校验,CPP侧直接使用即可;bool类型,1代表true,0代表false
     void ConfigGlobalEnv()
     {
-        // 设置梯度策略
-        const char *envStrategy = getenv(RecEnvNames::APPLY_GRADIENTS_STRATEGY);
-        if (envStrategy != nullptr) {
-            GlobalEnv::applyGradientsStrategy = envStrategy;
-        }
-
         // 设置ACL超时时间
         const char *envAclTimeout = getenv(RecEnvNames::ACL_TIMEOUT);
         if (envAclTimeout != nullptr) {
@@ -117,9 +105,8 @@
     void LogGlobalEnv()
     {
-        LOG_DEBUG("Environment variables are: [{}: {}], [{}: {}], [{}: {}], [{}: {}], [{}: {}], [{}: {}], "
+        LOG_DEBUG("Environment variables are: [{}: {}], [{}: {}], [{}: {}], [{}: {}], [{}: {}], "
             "[{}: {}], [{}: {}], [{}: {}], [{}: {}], [{}: {}], [{}: {}], [{}: {}]",
-            RecEnvNames::APPLY_GRADIENTS_STRATEGY, GlobalEnv::applyGradientsStrategy,
             RecEnvNames::ACL_TIMEOUT, GlobalEnv::aclTimeout,
             RecEnvNames::HD_CHANNEL_SIZE, GlobalEnv::hdChannelSize,
             RecEnvNames::KEY_PROCESS_THREAD_NUM, GlobalEnv::keyProcessThreadNum,
diff --git a/src/core/utils/config.h b/src/core/utils/config.h
index 4c56c0d428936f12b5a2af601741a52a9c2dc9ad..fc5536f6920a499080e19fc2efd7a123268c34a7 100644
--- a/src/core/utils/config.h
+++ b/src/core/utils/config.h
@@ -16,11 +16,8 @@ See the License for the specific language governing permissions and
 #ifndef MXREC_CONFIG_H
 #define MXREC_CONFIG_H
 
-#include <string>
-
 namespace MxRec {
     namespace RecEnvNames {
-        const char *const APPLY_GRADIENTS_STRATEGY = "APPLY_GRADIENTS_STRATEGY";
         const char *const ACL_TIMEOUT = "AclTimeout";
         const char *const HD_CHANNEL_SIZE = "HD_CHANNEL_SIZE";
         const char *const KEY_PROCESS_THREAD_NUM = "KEY_PROCESS_THREAD_NUM";
@@ -34,13 +31,7 @@
         const char *const RECORD_KEY_COUNT = "RECORD_KEY_COUNT";
     };
 
-    namespace ApplyGradientsStrategyOptions {
-        extern const std::string DIRECT_APPLY;
-        extern const std::string SUM_SAME_ID_GRADIENTS_AND_APPLY;
-    };
-
     struct GlobalEnv {
-        static std::string applyGradientsStrategy;
         static int aclTimeout;
         static int hdChannelSize;
        static int keyProcessThreadNum;
diff --git a/src/core/utils/task_queue.h b/src/core/utils/task_queue.h
new file mode 100644
index 0000000000000000000000000000000000000000..a42e514776789dc747868483a03fa6790228310a
--- /dev/null
+++ b/src/core/utils/task_queue.h
@@ -0,0 +1,110 @@
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
+
+#ifndef TASK_QUEUE_H
+#define TASK_QUEUE_H
+
+#include <condition_variable>
+#include <list>
+#include <mutex>
+#include <utility>
+
+namespace MxRec {
+    namespace Common {
+        template<typename T>
+        class TaskQueue {
+        public:
+            TaskQueue() = default;
+
+            ~TaskQueue() = default;
+
+            TaskQueue(TaskQueue const &other)
+            {
+                std::lock_guard<std::mutex> lk(other.mut);
+                dataQueue = other.dataQueue;
+            }
+
+            TaskQueue &operator=(TaskQueue const &other)
+            {
+                if (this == &other) {
+                    return *this;
+                }
+                std::lock_guard<std::mutex> lk(other.mut);
+                dataQueue = other.dataQueue;
+                return *this;
+            }
+
+            void Pushv(T &t)
+            {
+                std::lock_guard<std::mutex> lk(mut);
+                dataQueue.push_back(std::move(t));
+                dataCond.notify_one();
+            }
+
+            void Pushv(T &&t)
+            {
+                std::lock_guard<std::mutex> lk(mut);
+                dataQueue.emplace_back(t);
+                dataCond.notify_one();
+            }
+
+            T WaitAndPop()
+            {
+                std::unique_lock<std::mutex> lk(mut);
+                dataCond.wait(lk, [this] {
+                    if (!finished) {
+                        return !dataQueue.empty();
+                    } else {
+                        return true;
+                    }
+                });
+                T res;
+                if (finished) {
+                    return std::move(res);
+                }
+                res = std::move(dataQueue.front());
+                dataQueue.pop_front();
+                return std::move(res);
+            }
+
+            void DestroyQueue()
+            {
+                finished = true;
+                dataCond.notify_one();
+            }
+
+            bool Empty() const
+            {
+                std::lock_guard<std::mutex> lk(mut);
+                return dataQueue.empty();
+            }
+
+            size_t Size() const
+            {
+                std::lock_guard<std::mutex> lk(mut);
+                return dataQueue.size();
+            }
+
+        private:
+            mutable std::mutex mut;
+            std::list<T> dataQueue;
+            std::condition_variable dataCond;
+            bool finished = false;
+        };
+    }
+}
+
+
+#endif
diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc
index 85b8e1d09516e35920c804d2727d9955ca16c289..afc3fe3ad8696c9acea89c8e70ed3d270d4710f9 100644
--- a/src/dataset_tf/eos_dataset_op.cc
+++ b/src/dataset_tf/eos_dataset_op.cc
@@ -74,15 +74,15 @@ int CheckCommFinished(MPI_Request& req, int channelId)
 // Defines the immutable dataset; its MakeIterator() method tells TensorFlow how to create iterators over it.
 class EosDatasetOp::Dataset : public DatasetBase {
 public:
-    explicit Dataset(OpKernelContext *ctx, const DatasetBase *input, int32_t channelId, int32_t maxTrainSteps,
-                     int32_t maxEvalSteps)
+    explicit Dataset(OpKernelContext *ctx, const DatasetBase *input, int32_t channelId,
+                     int32_t maxTrainSteps,
+                     int32_t maxEvalSteps)
        : DatasetBase(DatasetContext(ctx)),
          input_(input),
          channelId_(channelId),
          maxTrainSteps_(maxTrainSteps),
          maxEvalSteps_(maxEvalSteps),
-          id_(g_datasetId[channelId])
-    {
+          id_(g_datasetId[channelId]) {
        input_->Ref();
        auto os_input = input->output_shapes();
        output_shapes_ = os_input;
@@ -93,12 +93,13 @@ public:
        MPI_Comm_size(g_comm[channelId], &g_rankSize);

        LOG_DEBUG("EosDataset: {} was born for channel: {}, maxTrainSteps: {}, maxEvalSteps: {}.",
-                 g_datasetId[channelId], channelId, maxTrainSteps, maxEvalSteps);
+                 g_datasetId[channelId], channelId, maxTrainSteps, maxEvalSteps);
        g_datasetId[channelId] += 1;
    }

-    Dataset(const Dataset&) = delete;
-    Dataset& operator=(const Dataset&) = delete;
+    Dataset(const Dataset &) = delete;
+
+    Dataset &operator=(const Dataset &) = delete;

    ~Dataset() override
    {
@@ -147,8 +148,10 @@ public:
    }

 protected:
-    Status AsGraphDefInternal(SerializationContext *ctx, DatasetGraphDefBuilder *b, Node **output) const override
-    {
+    Status
+    AsGraphDefInternal(SerializationContext *ctx, DatasetGraphDefBuilder *b,
+                       Node **output) const override
+    {
        Node *input_graph = nullptr;
        TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
        Node *channel_id_x = nullptr;
@@ -158,7 +161,8 @@ protected:
        Node *max_eval_steps_x = nullptr;
        TF_RETURN_IF_ERROR(b->AddScalar(maxEvalSteps_, &max_eval_steps_x));
        TF_RETURN_IF_ERROR(
-            b->AddDataset(this, { input_graph, channel_id_x, max_train_steps_x, max_eval_steps_x }, output));
+            b->AddDataset(this, {input_graph, channel_id_x, max_train_steps_x, max_eval_steps_x},
+                          output));
        return Status::OK();
    }

@@ -166,20 +170,27 @@ private:
    // Defines the mutable iterator over a particular dataset; its GetNextInternal() method tells TensorFlow
    // how to fetch the iterator's next element.
    class Iterator : public DatasetIterator<Dataset> {
    public:
-        explicit Iterator(const Params &params) : DatasetIterator<Dataset>(params), i_(0), iter_times_(0) {}
+        explicit Iterator(const Params &params) : DatasetIterator<Dataset>(params), i_(0),
+                                                  iter_times_(0) {}
+
 #if defined(TF_VERSION_TF2)
        Status Initialize(IteratorContext* ctx) override
        {
            return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
        }
 #else
+
        Status Initialize(IteratorContext *ctx) override
        {
            return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
        }
+
 #endif
-        Status GetNextInternal(IteratorContext *ctx, std::vector<Tensor> *out_tensors, bool *end_of_sequence) override
-        {
+
+        Status
+        GetNextInternal(IteratorContext *ctx, std::vector<Tensor> *out_tensors,
+                        bool *end_of_sequence) override
+        {
            mutex_lock l(mu_);
            if (!input_impl_) {
                *end_of_sequence = true;
@@ -202,12 +213,14 @@ private:
                getNextStatus = GET_NEXT_TERMINATE;

                MPI_Request req;
-                MPI_Iallreduce(MPI_IN_PLACE, &getNextStatus, 1, MPI_INT, MPI_SUM, g_comm[channelId], &req);
+                MPI_Iallreduce(MPI_IN_PLACE, &getNextStatus, 1, MPI_INT, MPI_SUM, g_comm[channelId],
+                               &req);
                CheckCommFinished(req, channelId);
                keyProcess->SetEos(1, dataset()->channelId_);
-                LOG_DEBUG("[ACTIVE] GetNext eos was triggered actively, channel: {}, iter: {}", dataset()->channelId_,
-                          iter_times_);
+                LOG_DEBUG("[ACTIVE] GetNext eos was triggered actively, channel: {}, iter: {}",
+                          dataset()->channelId_,
+                          iter_times_);

                input_impl_.reset();
                return Status::OK();
@@ -220,7 +233,8 @@ private:
            if (getNextStatus < g_rankSize) {
                *end_of_sequence = true;
                keyProcess->SetEos(1, dataset()->channelId_);
-                LOG_DEBUG("[PASSIVE] GetNext eos was triggered passively, channel: {}, iter: {}, sum: {}",
+                LOG_DEBUG(
+                    "[PASSIVE] GetNext eos was triggered passively, channel: {}, iter: {}, sum: {}",
                    dataset()->channelId_, iter_times_, getNextStatus);

                input_impl_.reset();
@@ -232,11 +246,12 @@
        }

 protected:
-        std::shared_ptr<model::Node> CreateNode(
-            IteratorContext* ctx, model::Node::Args args) const override
-        {
-            return model::MakeKnownRatioNode(std::move(args), /* ratio= */ 1);
+        std::shared_ptr<model::Node> CreateNode(
+            IteratorContext *ctx, model::Node::Args args) const override
+        {
+            return model::MakeKnownRatioNode(std::move(args), 1); // ratio = 1
        }
+
 #if defined(TF_VERSION_TF2)
        Status SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override
        {
@@ -244,15 +259,18 @@
            return Status::OK();
        }
 #else
-        Status SaveInternal(IteratorStateWriter* writer) override
+
+        Status SaveInternal(IteratorStateWriter *writer) override
        {
            TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
            return Status::OK();
        }
+
 #endif
-        Status RestoreInternal(IteratorContext* ctx,
-                               IteratorStateReader* reader) override
-        {
+
+        Status RestoreInternal(IteratorContext *ctx,
+                               IteratorStateReader *reader) override
+        {
            mutex_lock l(mu_);
            TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
            return Status::OK();
@@ -261,11 +279,14 @@
    private:
        static constexpr int GET_NEXT_CONTINUE = 1;
        static constexpr int GET_NEXT_TERMINATE = 0;
-
+
        tensorflow::mutex mu_;
-        int64 i_ GUARDED_BY(mu_);
-        int64 iter_times_ GUARDED_BY(mu_);
-        std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+        int64 i_
+            GUARDED_BY(mu_);
+        int64 iter_times_
+            GUARDED_BY(mu_);
+        std::unique_ptr<IteratorBase> input_impl_
+            GUARDED_BY(mu_);
    };

    const DatasetBase *input_;
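The end-of-stream handling above is a small agreement protocol: on every `GetNext`, each rank contributes `GET_NEXT_CONTINUE` (1) while it still has data and `GET_NEXT_TERMINATE` (0) once its input is exhausted, and the contributions are summed with `MPI_Iallreduce`. A sum below `g_rankSize` means at least one rank hit the end, so every rank raises `end_of_sequence` in the same step and the channel stays in sync. A standalone sketch of that agreement (editor's illustration; the op itself polls the request through `CheckCommFinished` rather than blocking in `MPI_Wait`):

```cpp
#include <mpi.h>

// Returns true only while every rank in comm still has data.
bool AllRanksHaveData(bool localHasData, MPI_Comm comm)
{
    int status = localHasData ? 1 : 0; // GET_NEXT_CONTINUE / GET_NEXT_TERMINATE
    MPI_Request req;
    MPI_Iallreduce(MPI_IN_PLACE, &status, 1, MPI_INT, MPI_SUM, comm, &req);
    MPI_Wait(&req, MPI_STATUS_IGNORE); // simplified: block until the sum is ready
    int rankSize = 0;
    MPI_Comm_size(comm, &rankSize);
    return status == rankSize; // sum < rankSize: some rank is out of data
}
```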
diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp
index c3687e8a46fd62e6461999d2eb4ece748178cb8d..98fca9615339c4ed3dad1cad1ab37fb9c7153d49 100644
--- a/src/ops_tf/hybrid_dataset_ops.cpp
+++ b/src/ops_tf/hybrid_dataset_ops.cpp
@@ -403,7 +403,7 @@ namespace MxRec {
            out(0) = batchId;
            if (channelId == 1) {
                if (maxStep != -1 && batchId >= maxStep) {
-                    LOG_DEBUG(StringFormat("skip excess batch after {}/{}", batchId, maxStep));
+                    LOG_DEBUG(StringFormat("skip excess batch after %d/%d", batchId, maxStep));
                    return;
                }
            }
@@ -640,4 +640,22 @@ namespace tensorflow {
        });
    REGISTER_KERNEL_BUILDER(Name("EmbeddingUpdateByAddress").Device(DEVICE_CPU), MxRec::CustOps);
-}
\ No newline at end of file
+
+    // ######################## Register a TF op with the same name as the fused LazyAdam op ########################
+    REGISTER_OP("LazyAdam")
+        .Input("gradient: float32")
+        .Input("indices: int32")
+        .Input("input_m: float32")
+        .Input("input_v: float32")
+        .Input("input_var: float32")
+        .Input("lr: float32")
+        .Attr("beta1: float")
+        .Attr("beta2: float")
+        .Attr("epsilon: float")
+        .Output("output_m: float32")
+        .Output("output_v: float32")
+        .Output("output_var: float32")
+        .SetIsStateful()
+        .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
+    REGISTER_KERNEL_BUILDER(Name("LazyAdam").Device(DEVICE_CPU), MxRec::CustOps);
+}
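The `LazyAdam` registration above only declares a CPU placeholder with the fused op's signature; the actual kernel runs on the device. For reference, the inputs imply the usual lazy (sparse) Adam rule, in which only the rows selected by `indices` are touched. A sketch of those semantics under stated assumptions (row-major `[numRows, dim]` buffers; any bias correction folded into `lr` by the caller; not the fused kernel itself):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Lazy Adam: update m, v and var only at the rows named by indices.
void LazyAdamRows(std::vector<float>& m, std::vector<float>& v, std::vector<float>& var,
                  const std::vector<float>& grad, const std::vector<int>& indices,
                  std::size_t dim, float lr, float beta1, float beta2, float epsilon)
{
    for (std::size_t i = 0; i < indices.size(); ++i) {
        const float* g = &grad[i * dim]; // gradient row i pairs with indices[i]
        const std::size_t base = static_cast<std::size_t>(indices[i]) * dim;
        for (std::size_t d = 0; d < dim; ++d) {
            float& mRef = m[base + d];
            float& vRef = v[base + d];
            mRef = beta1 * mRef + (1.0f - beta1) * g[d];        // first moment
            vRef = beta2 * vRef + (1.0f - beta2) * g[d] * g[d]; // second moment
            var[base + d] -= lr * mRef / (std::sqrt(vRef) + epsilon);
        }
    }
}
```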
"); + } + return count; + } + PYBIND11_MODULE(mxrec_pybind, m) { m.def("get_ub_hot_size", &GetUBHotSize, py::arg("device_id")); m.def("get_logic_id", &GetLogicID, py::arg("physic_id")); - m.attr("USE_STATIC") = py::int_(HybridOption::USE_STATIC); + m.def("get_device_count", &GetDeviceCount); - m.attr("USE_HOT") = py::int_(HybridOption::USE_HOT); + m.attr("USE_STATIC") = py::int_(HybridOption::USE_STATIC); m.attr("USE_DYNAMIC_EXPANSION") = py::int_(HybridOption::USE_DYNAMIC_EXPANSION); + m.attr("USE_SUM_SAME_ID_GRADIENTS") = py::int_(HybridOption::USE_SUM_SAME_ID_GRADIENTS); + GetRankInfo(m); GetEmbInfoParams(m); @@ -126,7 +138,7 @@ namespace { { pybind11::class_(m, "EmbInfo") .def(pybind11::init, - std::vector&, std::vector&>(), + std::vector&, std::vector&>(), py::arg("embInfoParams"), py::arg("vocab_size"), py::arg("initialize_infos"), @@ -164,36 +176,38 @@ namespace { void GetInitializeInfo(pybind11::module_ &m) { - pybind11::class_(m, "InitializeInfo") - .def(py::init(), py::arg("name"), py::arg("start"), - py::arg("len"), py::arg("constant_initializer_info")) - .def(py::init(), py::arg("name"), py::arg("start"), - py::arg("len"), py::arg("normal_initializer_info")) - .def_readwrite("name", &InitializeInfo::name) - .def_readwrite("start", &InitializeInfo::start) - .def_readwrite("len", &InitializeInfo::len) - .def_readwrite("ConstantInitializerInfo", &InitializeInfo::constantInitializerInfo) - .def_readwrite("NormalInitializerInfo", &InitializeInfo::normalInitializerInfo); + pybind11::class_(m, "InitializeInfo") + .def(py::init(), + py::arg("name"), py::arg("start"), py::arg("len"), py::arg("constant_initializer_info")) + .def(py::init(), + py::arg("name"), py::arg("start"), py::arg("len"), py::arg("normal_initializer_info")) + .def_readwrite("name", &EmbCache::InitializerInfo::name) + .def_readwrite("start", &EmbCache::InitializerInfo::start) + .def_readwrite("len", &EmbCache::InitializerInfo::len) + .def_readwrite("ConstantInitializerInfo", &EmbCache::InitializerInfo::constantInitializerInfo) + .def_readwrite("NormalInitializerInfo", &EmbCache::InitializerInfo::normalInitializerInfo); } void GetConstantInitializerInfo(pybind11::module_ &m) { - pybind11::class_(m, "ConstantInitializerInfo") - .def(py::init(), py::arg("constant_val") = 0, py::arg("initK") = 1.0) - .def_readwrite("constant_val", &ConstantInitializerInfo::constantValue) - .def_readwrite("initK", &ConstantInitializerInfo::initK); + pybind11::class_(m, "ConstantInitializerInfo") + .def(py::init(), py::arg("constant_val") = 0, py::arg("initK") = 1.0) + .def_readwrite("constant_val", &EmbCache::ConstantInitializerInfo::constantValue) + .def_readwrite("initK", &EmbCache::ConstantInitializerInfo::initK); } void GetNormalInitializerInfo(pybind11::module_ &m) { - pybind11::class_(m, "NormalInitializerInfo") - .def(py::init(), py::arg("mean") = 0.0, - py::arg("stddev") = 1.0, py::arg("seed") = 0, - py::arg("initK") = 1.0) - .def_readwrite("mean", &NormalInitializerInfo::mean) - .def_readwrite("stddev", &NormalInitializerInfo::stddev) - .def_readwrite("seed", &NormalInitializerInfo::seed) - .def_readwrite("initK", &NormalInitializerInfo::initK); + pybind11::class_(m, "NormalInitializerInfo") + .def(py::init(), + py::arg("mean") = 0.0, + py::arg("stddev") = 1.0, + py::arg("seed") = 0, + py::arg("initK") = 1.0) + .def_readwrite("mean", &EmbCache::NormalInitializerInfo::mean) + .def_readwrite("stddev", &EmbCache::NormalInitializerInfo::stddev) + .def_readwrite("seed", &EmbCache::NormalInitializerInfo::seed) + .def_readwrite("initK", 
&EmbCache::NormalInitializerInfo::initK); } void GetHybridMgmt(pybind11::module_& m) @@ -204,9 +218,11 @@ namespace { py::arg("seed") = DEFAULT_RANDOM_SEED, py::arg("threshold_values") = vector {}, py::arg("if_load") = false) .def("save", &MxRec::HybridMgmt::Save, py::arg("save_path") = "") - .def("load", &MxRec::HybridMgmt::Load, py::arg("load_path") = "") + .def("load", &MxRec::HybridMgmt::Load, py::arg("load_path") = "", + py::arg("warm_start_tables") = vector {}) .def("destroy", &MxRec::HybridMgmt::Destroy) .def("evict", &MxRec::HybridMgmt::Evict) + .def("fetch_device_emb", &MxRec::HybridMgmt::FetchDeviceEmb) .def("send", &MxRec::HybridMgmt::SendHostMap, py::arg("table_name") = "") .def("send_load_offset", &MxRec::HybridMgmt::SendLoadMap, py::arg("table_name") = "") .def("receive", &MxRec::HybridMgmt::ReceiveHostMap, py::arg("key_offset_map")) diff --git a/src/test_ut.sh b/src/test_ut.sh index 6146aaab84ccbfdde78ec21d6bfe95d6e181d49a..7305c0819ee8c84e75bf17c1902fd3bb27271789 100644 --- a/src/test_ut.sh +++ b/src/test_ut.sh @@ -129,6 +129,9 @@ mkdir build cd build python_path="$(dirname "$(dirname "$(which python3.7)")")" +# config asan environment variable +export ASAN_OPTIONS=halt_on_error=1:detect_leaks=1:fast_unwind_on_malloc=0 +export LSAN_OPTIONS=suppressions=../tests/leaks.supp cmake -DCMAKE_BUILD_TYPE=Debug \ -DTF_PATH="${python_path}"/lib/python3.7/site-packages/"${TF_DIR}" \ diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp index ad7bf34d0d192d36dc0e0f3f7d9bc044e8db9cdc..8d2963636bb816da515db3f974edfc80e43033be 100644 --- a/src/tests/checkpoint/checkpoint_test.cpp +++ b/src/tests/checkpoint/checkpoint_test.cpp @@ -143,7 +143,7 @@ protected: } } - void SetDDRKeyFreqMap(unordered_map& testDDRKeyFreqMap) + void SetDDRKeyFreqMap(unordered_map& testDDRKeyFreqMap) { for (int64_t i { 0 }; i < hostVocabSize; ++i) { testDDRKeyFreqMap[featMem] = i; @@ -159,7 +159,7 @@ protected: } } - void SetExcludeDDRKeyFreqMap(unordered_map& testExcludeDDRKeyFreqMap) + void SetExcludeDDRKeyFreqMap(unordered_map& testExcludeDDRKeyFreqMap) { for (int64_t i { 0 }; i < hostVocabSize; ++i) { testExcludeDDRKeyFreqMap[featMem] = i; @@ -169,7 +169,7 @@ protected: void SetDDRKeyFreqMaps(KeyFreqMemT& testDDRKeyFreqMaps) { - unordered_map testDDRKeyFreqMap; + unordered_map testDDRKeyFreqMap; for (const auto& testEmbInfo : testEmbInfos) { SetDDRKeyFreqMap(testDDRKeyFreqMap); testDDRKeyFreqMaps[testEmbInfo.name] = std::move(testDDRKeyFreqMap); @@ -187,7 +187,7 @@ protected: void SetExcludeDDRKeyFreqMaps(KeyFreqMemT& testExcludeDDRKeyFreqMaps) { - unordered_map testExcludeDDRKeyFreqMap; + unordered_map testExcludeDDRKeyFreqMap; for (const auto& testEmbInfo : testEmbInfos) { SetExcludeDDRKeyFreqMap(testExcludeDDRKeyFreqMap); testExcludeDDRKeyFreqMaps[testEmbInfo.name] = std::move(testExcludeDDRKeyFreqMap); diff --git a/src/tests/emb_hashmap/emb_hashmap_test.cpp b/src/tests/emb_hashmap/emb_hashmap_test.cpp deleted file mode 100644 index ac2f1583ed777b71cab6f63426c5252ef06772cd..0000000000000000000000000000000000000000 --- a/src/tests/emb_hashmap/emb_hashmap_test.cpp +++ /dev/null @@ -1,185 +0,0 @@ -/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and - limitations under the License. -==============================================================================*/ - -#include -#include - -#include "emb_hashmap/emb_hashmap.h" -#include "hybrid_mgmt/hybrid_mgmt_block.h" -#include "ssd_cache/cache_manager.h" -#include "utils/common.h" - -using namespace std; -using namespace MxRec; -using namespace testing; - -const int HBM_VOCAB_SIZE = 10; -const int DDR_VOCAB_SIZE = 100; -const int SSD_VOCAB_SIZE = 100; -const int INT_2 = 2; -const int INT_4 = 4; -const int INT_21 = 21; -const int INT_42 = 42; -const int NEGATIVE_INT_1 = -1; - -// 刷新换入换出频次和打印信息 -void RefreshSwapFreqInfoAndPrint(EmbHashMap& hostHashMaps, string embTableName, int opTimes) -{ - auto& embHashMap = hostHashMaps.embHashMaps[embTableName]; - hostHashMaps.RefreshFreqInfoWithSwap(embTableName, embHashMap); - vector hbm2DdrKeyList; - vector ddr2HbmKeyList; - for (auto it : embHashMap.oldSwap) { - hbm2DdrKeyList.emplace_back(it.first); - ddr2HbmKeyList.emplace_back(it.second); - } - LOG_INFO("embHashMap hbm2DdrKeyList: {}", VectorToString(hbm2DdrKeyList)); - LOG_INFO("embHashMap ddr2HbmKeyList: {}", VectorToString(ddr2HbmKeyList)); - embHashMap.oldSwap.clear(); - LOG_INFO("RefreshSwapFreqInfoAndPrint end, opTimes:{}", opTimes); -} - -vector GetEmbInfoList() -{ - EmbInfo embInfo; - embInfo.name = "table1"; - embInfo.devVocabSize = HBM_VOCAB_SIZE; - embInfo.hostVocabSize = DDR_VOCAB_SIZE; - embInfo.ssdVocabSize = SSD_VOCAB_SIZE; - embInfo.ssdDataPath = {"ssd_data"}; - vector embInfos; - embInfos.emplace_back(embInfo); - return embInfos; -} - -// 测试HBM与DDR换入换出时CacheManager模块频次刷新 -TEST(EmbHashMap, TestFindOffset) -{ - LOG_INFO("start TestFindOffset"); - string embTableName = "table1"; - EmbHashMap hostHashMaps; - RankInfo rankInfo; - rankInfo.isDDR = true; - auto embInfo = GetEmbInfoList(); - hostHashMaps.Init(rankInfo, embInfo, false); - CacheManager cacheManager; - cacheManager.Init(nullptr, embInfo); - bool isSSDEnabled = true; - hostHashMaps.isSSDEnabled = isSSDEnabled; - hostHashMaps.cacheManager = &cacheManager; - int channelId = 0; - size_t currentBatchId = 0; - size_t keepBatchId = 0; - int opTimes = 0; - - vector keys = {1, 2, 3, 4, 5}; - hostHashMaps.FindOffset(embTableName, keys, currentBatchId++, keepBatchId++, channelId); - RefreshSwapFreqInfoAndPrint(hostHashMaps, embTableName, opTimes++); - - vector keys2 = {6, 7, 8, 9, 10}; - hostHashMaps.FindOffset(embTableName, keys2, currentBatchId++, keepBatchId++, channelId); - RefreshSwapFreqInfoAndPrint(hostHashMaps, embTableName, opTimes++); - - auto& excludeKeyMap = cacheManager.excludeDDRKeyCountMap[embTableName]; - auto& ddrKeyMap = cacheManager.ddrKeyFreqMap[embTableName]; - - auto logLevelTemp = Logger::GetLevel(); - Logger::SetLevel(Logger::TRACE); - vector keys4 = {21, 21, 21, 21}; // 新key重复值, 且需要换入换出 - hostHashMaps.FindOffset(embTableName, keys4, currentBatchId++, keepBatchId++, channelId); - RefreshSwapFreqInfoAndPrint(hostHashMaps, embTableName, opTimes++); - ASSERT_EQ(excludeKeyMap[INT_21], INT_4); - ASSERT_EQ(ddrKeyMap.Get(1), 1); - - keys4 = {41, 42, 43, 44, 45, 46, 47, 48, 49, 50}; // 整个hbm大小key换入换出 - hostHashMaps.FindOffset(embTableName, keys4, 
currentBatchId++, keepBatchId++, channelId); - RefreshSwapFreqInfoAndPrint(hostHashMaps, embTableName, opTimes++); - ASSERT_EQ(ddrKeyMap.Get(INT_21), INT_4); - - keys4 = {51, 52, 53, 1, 2, 21, 41, 42, 43, 44}; // 3个新key, 3个在ddr, 4个在hbm - hostHashMaps.FindOffset(embTableName, keys4, currentBatchId, keepBatchId, channelId); - RefreshSwapFreqInfoAndPrint(hostHashMaps, embTableName, opTimes); - ASSERT_EQ(excludeKeyMap[1], INT_2); - ASSERT_EQ(excludeKeyMap[INT_42], INT_2); - ASSERT_EQ(ddrKeyMap.Get(INT_21), NEGATIVE_INT_1); - ASSERT_EQ(ddrKeyMap.Get(1), NEGATIVE_INT_1); - Logger::SetLevel(logLevelTemp); // 恢复日志级别 - LOG_INFO("test TestFindOffset end."); -} - -TEST(EmbHashMap, TESTGetHashMaps) -{ - string embTableName = "table1"; - EmbHashMap hostHashMaps; - RankInfo rankInfo; - rankInfo.isDDR = true; - auto embInfo = GetEmbInfoList(); - hostHashMaps.Init(rankInfo, embInfo, false); - CacheManager cacheManager; - cacheManager.Init(nullptr, embInfo); - hostHashMaps.isSSDEnabled = true; - hostHashMaps.cacheManager = &cacheManager; - int channelId = 0; - size_t currentBatchId = 0; - size_t keepBatchId = 0; - int opTimes = 0; - - vector keys = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; - hostHashMaps.FindOffset(embTableName, keys, currentBatchId++, keepBatchId++, channelId); - RefreshSwapFreqInfoAndPrint(hostHashMaps, embTableName, opTimes++); - auto testEmbHashMap = hostHashMaps.GetHashMaps().at(embTableName); - hostHashMaps.embHashMaps.at(embTableName).maxOffsetOld = testEmbHashMap.maxOffset; - // 增加10个key, offset长度变为10 - ASSERT_EQ(testEmbHashMap.maxOffset, 10); - - keys = {11, 12, 13, 14, 15, 16, 17, 18, 19, 20}; - hostHashMaps.FindOffset(embTableName, keys, currentBatchId++, keepBatchId++, channelId); - RefreshSwapFreqInfoAndPrint(hostHashMaps, embTableName, opTimes++); - testEmbHashMap = hostHashMaps.GetHashMaps().at(embTableName); - // 再增加10个key,offset变为20 - ASSERT_EQ(testEmbHashMap.maxOffset, 20); - - HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); - hybridMgmtBlock->lastRunChannelId = channelId; - hybridMgmtBlock->hybridBatchId[0] = 1; - testEmbHashMap = hostHashMaps.GetHashMaps().at(embTableName); - // 回退一步,offset变回10 - ASSERT_EQ(testEmbHashMap.maxOffset, 10); - - hybridMgmtBlock->hybridBatchId[0] = 2; - // 回退2步,抛出异常 - ASSERT_THROW(hostHashMaps.GetHashMaps(), HybridMgmtBlockingException); - hybridMgmtBlock->hybridBatchId[0] = 0; - - keys = {10, 11}; - hostHashMaps.EvictDeleteEmb(embTableName, keys); - testEmbHashMap = hostHashMaps.GetHashMaps().at(embTableName); - // 淘汰1个hbm key和1个ddr key,表中无法查找到该key - ASSERT_EQ(testEmbHashMap.hostHashMap.find(10), testEmbHashMap.hostHashMap.end()); - ASSERT_EQ(testEmbHashMap.hostHashMap.find(11), testEmbHashMap.hostHashMap.end()); - ASSERT_EQ(cacheManager.excludeDDRKeyCountMap[embTableName][11], 0); - ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].Get(10), -1); - - keys = {1, 2}; - hostHashMaps.FindOffset(embTableName, keys, currentBatchId++, keepBatchId++, channelId); - RefreshSwapFreqInfoAndPrint(hostHashMaps, embTableName, opTimes++); - testEmbHashMap = hostHashMaps.GetHashMaps().at(embTableName); - // 从ddr中换回2个key到hbm,交换变量长度为2 - ASSERT_EQ(testEmbHashMap.ddr2HbmKeys.size(), 2); - hostHashMaps.ClearLookupAndSwapOffset(hostHashMaps.embHashMaps.at(embTableName)); - testEmbHashMap = hostHashMaps.GetHashMaps().at(embTableName); - // 清理后,交换变量长度为0 - ASSERT_EQ(testEmbHashMap.ddr2HbmKeys.size(), 0); -} \ No newline at end of file diff --git a/src/tests/emb_mgmt/emb_mgmt_test.cpp b/src/tests/emb_mgmt/emb_mgmt_test.cpp index 
e47f3b4f73288f0563c83b6b2d195297711ccda5..4924abf1024e0c2fbe578f4ffc22b8bc7285f467 100644 --- a/src/tests/emb_mgmt/emb_mgmt_test.cpp +++ b/src/tests/emb_mgmt/emb_mgmt_test.cpp @@ -15,7 +15,6 @@ See the License for the specific language governing permissions and #include #include "hybrid_mgmt/hybrid_mgmt.h" -#include "host_emb/host_emb.h" #include "utils/common.h" using namespace std; @@ -62,30 +61,6 @@ protected: string constantInitializerName = "constant_initializer"; int nBatch = 10; - void UpdateEmb(vector &missingKeysHostPos, int channelId, const string &embName, - std::unique_ptr &hostEmb, vector &d2h_emb) - { - LOG_INFO(HD + "update emb start"); - if (d2h_emb.size() == 0) { - LOG_INFO(HD + "emb is none channelId:{}", channelId); - return; - } - - auto tensorPtr = d2h_emb[0].flat().data(); - for (size_t i = 0; i < missingKeysHostPos.size(); i++) { - (hostEmb->GetEmb(embName).embData[missingKeysHostPos[i]]).assign( - tensorPtr, - tensorPtr + hostEmb->GetEmb(embName).hostEmbInfo.extEmbeddingSize); - tensorPtr = tensorPtr + hostEmb->GetEmb(embName).hostEmbInfo.extEmbeddingSize; - } - for (size_t i = 0; i < hostEmb->GetEmb(embName).embData.size(); ++i) { - LOG_INFO("hostEmb: embName {}, {} is: {}", embName, i, - VectorToString(hostEmb->GetEmb(embName).embData[i])); - } - LOG_INFO(HD + "update emb end"); - d2h_emb.clear(); - } - bool Float2TensorVec(const vector>& Datas, vector& tensors) { tensors.clear(); @@ -116,63 +91,6 @@ protected: // delete } }; -#ifndef GTEST -TEST_F(EmbMgmtTest, Initialize) -{ - vector vocabsize = { devVocabSize, hostVocabSize }; - aoto param = EmbInfoParams(name, sendCount, embeddingSize, extEmbeddingSize, isSave) - embInfo = EmbInfo(param, vocabsize, initializeInfos); - embInfos.emplace_back(embInfo); - vector thresholdValues = {}; - - auto hybridMgmt = Singleton::GetInstance(); - cout << "setup..." 
<< endl; - - allRank = RankInfo(GlogConfig::gRankId, deviceId, localRankSize, useStatic, nBatch, maxStep); - hybridMgmt->Initialize(allRank, embInfos, seed, thresholdValues, false); - auto hostEmbs = make_unique(); - hostEmbs->Initialize(embInfos, seed); - auto hostHashMaps = make_unique(); - hostHashMaps->Init(allRank, embInfos, false); - - int currentBatchId = 0; - vector lookupKeys = { 1, 3, 5, 7 }; - vector d2h_emb; - vector> tmpDatas; - vector tmpData; - hostHashMaps->Process(embInfo.name, lookupKeys, currentBatchId, tmpData); - auto missingKeys = hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos; - LOG_INFO("missingKeys {}", missingKeys); - hostEmbs->EmbDataGenerator(initializeInfos, seed, missingKeys.size(), embeddingSize, tmpDatas); - auto status = Float2TensorVec(tmpDatas, d2h_emb); - ASSERT_EQ(status, true); - UpdateEmb(missingKeys, 0, embInfo.name, hostEmbs, d2h_emb); - hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos.clear(); - - lookupKeys = { 2, 3, 5, 6 }; - hostHashMaps->Process(embInfo.name, lookupKeys, currentBatchId, tmpData); - missingKeys = hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos; - LOG_INFO("missingKeys {}", missingKeys); - hostEmbs->EmbDataGenerator(initializeInfos, seed, missingKeys.size(), embeddingSize, tmpDatas); - status = Float2TensorVec(tmpDatas, d2h_emb); - ASSERT_EQ(status, true); - UpdateEmb(missingKeys, 0, embInfo.name, hostEmbs, d2h_emb); - hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos.clear(); - - lookupKeys = { 1, 7, 9, 10 }; - hostHashMaps->Process(embInfo.name, lookupKeys, currentBatchId, tmpData); - missingKeys = hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos; - LOG_INFO("missingKeys {}", missingKeys); - hostEmbs->EmbDataGenerator(initializeInfos, seed, missingKeys.size(), embeddingSize, tmpDatas); - Float2TensorVec(tmpDatas, d2h_emb); - status = Float2TensorVec(tmpDatas, d2h_emb); - ASSERT_EQ(status, true); - UpdateEmb(missingKeys, 0, embInfo.name, hostEmbs, d2h_emb); - hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos.clear(); - - hybridMgmt->Destroy(); -} -#endif #ifndef GTEST TEST_F(EmbMgmtTest, Initialize_HBM) diff --git a/src/tests/emb_table/emb_table_test.cpp b/src/tests/emb_table/emb_table_test.cpp deleted file mode 100644 index b26b4487075a46af4c8a53ac1d55fa0c9ce0b9bb..0000000000000000000000000000000000000000 --- a/src/tests/emb_table/emb_table_test.cpp +++ /dev/null @@ -1,135 +0,0 @@ -/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and - limitations under the License. 
-==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include -#include "utils/common.h" -#include "emb_table/emb_table.h" - -using namespace std; -using namespace MxRec; -using namespace testing; -using namespace tensorflow; - -class EmbTableTest : public testing::Test { -protected: - void SetUp() - { - // 设置测试用的EmbInfo - embInfo.extEmbeddingSize = embTable.TEST_EMB_SIZE; - LOG_INFO("EmbTable BLOCK_EMB_COUNT {} INIT_BLOCK_COUNT {}", - embTable.BLOCK_EMB_COUNT, embTable.INIT_BLOCK_COUNT); - rankInfo.rankId = 0; - rankInfo.rankSize = 1; - rankInfo.localRankSize = 1; - rankInfo.useStatic = true; - rankInfo.localRankId = 0; - rankInfo.isDDR = true; - rankInfo.ctrlSteps = { 1, -1 }; - rankInfo.deviceId = 0; - // 初始化EmbeddingTable -#ifndef GTEST - LOG_INFO("rank {} running", rankInfo.deviceId); - aclInit(nullptr); -#endif - } - - EmbTable embTable; - EmbInfo embInfo; - RankInfo rankInfo; - aclrtContext context; - - void TearDown() { - } -}; - -// 测试初始化是否正常 -TEST_F(EmbTableTest, Init) -{ -#ifndef GTEST - // 测试初始化是否出现异常 - EXPECT_NO_THROW(embTable.Init(embInfo, rankInfo, 0)); - LOG_INFO("embTable Init succeed!"); - ASSERT_EQ(embTable.rankInfo.g_rankId, rankInfo.g_rankId); - ASSERT_EQ(embTable.rankInfo.rankSize, rankInfo.rankSize); - ASSERT_EQ(embTable.rankInfo.localRankSize, rankInfo.localRankSize); - ASSERT_EQ(embTable.rankInfo.useStatic, rankInfo.useStatic); - ASSERT_EQ(embTable.rankInfo.localRankId, rankInfo.localRankId); - // 测试容量是否正常 - LOG_INFO("totalCapacity {}, INIT_BLOCK_COUNT {}", embTable.totalCapacity, embTable.INIT_BLOCK_COUNT); - EXPECT_EQ(embTable.totalCapacity, embTable.INIT_BLOCK_COUNT * embTable.BLOCK_EMB_COUNT); -#endif -} - -// 测试embedding list为空时的情况 -TEST_F(EmbTableTest, GetEmbAddressEmptyList) -{ -#ifndef GTEST - embTable.Init(embInfo, rankInfo, 0); - while (!embTable.embeddingList.empty()) { - float *embAddr = reinterpret_cast(embTable.GetEmbAddress()); - EXPECT_NE(embAddr, nullptr); - } - ASSERT_EQ(embTable.embeddingList.size(), 0); - - float *curAddr = nullptr; - int usedCapacityBefore = embTable.usedCapacity; - ASSERT_NO_THROW({ - curAddr= reinterpret_cast(embTable.GetEmbAddress()); - }); - EXPECT_NE(curAddr, nullptr); - EXPECT_EQ(embTable.usedCapacity, usedCapacityBefore + 1); -#endif -} - -// 测试正常情况 -TEST_F(EmbTableTest, GetEmbAddressNormal) -{ -#ifndef GTEST - embTable.Init(embInfo, rankInfo, 0); - ASSERT_EQ(embTable.totalCapacity, embTable.INIT_BLOCK_COUNT); - float *curAddr = nullptr; - int totalCapacityBefore = embTable.totalCapacity; - int usedCapacityBefore = embTable.usedCapacity; - ASSERT_NO_THROW({ - curAddr = reinterpret_cast(embTable.GetEmbAddress()); - }); - EXPECT_NE(curAddr, nullptr); - EXPECT_EQ(embTable.totalCapacity, totalCapacityBefore); - EXPECT_EQ(embTable.usedCapacity, usedCapacityBefore + 1); -#endif -} - -// 测试将一个emb地址放入embeddingList中,是否成功 -TEST_F(EmbTableTest, PutEmbAddress) -{ -#ifndef GTEST - embTable.Init(embInfo, rankInfo, 0); - int64_t curAddr; - int usedCapacityBefore = embTable.usedCapacity; - ASSERT_NO_THROW({ - curAddr = embTable.GetEmbAddress(); - }); - EXPECT_EQ(embTable.usedCapacity, usedCapacityBefore + 1); - embTable.PutEmbAddress(curAddr); - EXPECT_EQ(embTable.usedCapacity, usedCapacityBefore); - EXPECT_EQ(curAddr, reinterpret_cast(embTable.embeddingList.back())); -#endif -} diff --git a/src/tests/emb_table/embedding_ddr_test.cpp b/src/tests/emb_table/embedding_ddr_test.cpp index 
374a1392e9eec0c5886fe68ae56ad6d03a72ddc5..097167f6df3633b1766b5d8691683a7d6c5e209e 100644 --- a/src/tests/emb_table/embedding_ddr_test.cpp +++ b/src/tests/emb_table/embedding_ddr_test.cpp @@ -22,9 +22,7 @@ See the License for the specific language governing permissions and #include #include #include "utils/common.h" -#include "emb_table/emb_table.h" #include "emb_table/embedding_ddr.h" -#include "host_emb/host_emb.h" using namespace std; using namespace MxRec; @@ -36,8 +34,8 @@ protected: EmbeddingDDRTest() { struct EmbInfoParams embParam(string("test1"), 0, 1000, 2000, true, true); - std::vector vocabsize = {100}; - std::vector initializeInfos = {}; + std::vector vocabsize = {100, 100, 100}; + vector initializeInfos = {}; std::vector ssdDataPath = {""}; vector maxStep = {1000}; embInfo_ = EmbInfo(embParam, vocabsize, initializeInfos, ssdDataPath); @@ -75,79 +73,6 @@ protected: */ TEST_F(EmbeddingDDRTest, SaveLoadEmbeddingData) { - vector embInfos = {embInfo_}; - HostEmb* hostEmbs = Singleton::GetInstance(); - hostEmbs->Initialize(embInfos, 0); - HostEmbTable& table = hostEmbs->GetEmb("test1"); - - vector tmp1 {1.1, 2.1, 3.1}; - vector tmp2 {1.2, 2.2, 3.2}; - vector tmp3 {1.3, 2.3, 3.3}; - vector> testData; - testData.push_back(tmp1); - testData.push_back(tmp2); - testData.push_back(tmp3); - - for (vector& tmp : testData) { - table.embData.push_back(tmp); - } - - shared_ptr ddr1 = std::make_shared(embInfo_, rankInfo_, 0); - shared_ptr ddr2 = std::make_shared(embInfo_, rankInfo_, 0); - ddr1->Save("test_dir"); - // 修改成0 - for (vector& tmp: table.embData) { - for (float& t : tmp) { - t = 0; - } - } - bool fileExist = false; - if (access("./test_dir/test1/embedding", F_OK) == 0) { - fileExist = true; - } - EXPECT_EQ(fileExist, true); -} - -/** - * 测试基本查找 - */ -TEST_F(EmbeddingDDRTest, DDRBasic) -{ - shared_ptr table = std::make_shared(embInfo_, rankInfo_, 0); - const size_t testNum = 100; - vector testKeys; - vector testSwap; - for (size_t i = 0; i < testNum; ++i) { - testKeys.push_back(i); - } - table->FindOffset(testKeys, 0, TRAIN_CHANNEL_ID, testSwap); - EXPECT_EQ(testKeys.size(), 100); - EXPECT_EQ(testSwap.size(), 0); -} - -TEST_F(EmbeddingDDRTest, evict) -{ - shared_ptr table = std::make_shared(embInfo_, rankInfo_, 0); - const size_t testNum = 100; - vector testKeys; - vector testSwap; - for (size_t i = 0; i < testNum; ++i) { - testKeys.push_back(i); - } - table->FindOffset(testKeys, 0, TRAIN_CHANNEL_ID, testSwap); - table->EvictKeys(testKeys); - EXPECT_EQ(table->evictDevPos.size(), 100); - EXPECT_EQ(testKeys.size(), 100); - EXPECT_EQ(testSwap.size(), 0); -} - -TEST_F(EmbeddingDDRTest, FindSwap) -{ - shared_ptr table = std::make_shared(embInfo_, rankInfo_, 0); - const size_t testNum = 100; - vector testSwap; - table->FindSwapPosOld(0, 0, 0, testSwap); - EXPECT_EQ(testSwap.size(), 1); } TEST_F(EmbeddingDDRTest, EvictDeleteEmb) diff --git a/src/tests/emb_table/embedding_mgmt_test.cpp b/src/tests/emb_table/embedding_mgmt_test.cpp index 9374b0786fae783c04b8cf068ee1064b0816f218..81a354bf71edcb2edeb67170ac8e42c3bdf09b6b 100644 --- a/src/tests/emb_table/embedding_mgmt_test.cpp +++ b/src/tests/emb_table/embedding_mgmt_test.cpp @@ -22,7 +22,6 @@ See the License for the specific language governing permissions and #include #include #include "utils/common.h" -#include "emb_table/emb_table.h" #include "emb_table/embedding_mgmt.h" using namespace std; @@ -35,8 +34,8 @@ protected: EmbeddingMgmtTest() { struct EmbInfoParams embParam(string("test1"), 0, 1000, 2000, true, true); - std::vector vocabsize = 
{100}; - std::vector initializeInfos = {}; + std::vector vocabsize = {100, 100, 100}; + vector initializeInfos = {}; std::vector ssdDataPath = {""}; vector maxStep = {1000}; embInfo_ = EmbInfo(embParam, vocabsize, initializeInfos, ssdDataPath); @@ -75,7 +74,7 @@ TEST_F(EmbeddingMgmtTest, Init) ThresholdValue thvalue(tableName, 0, 0, 0, false); vector embInfos = {embInfo_}; vector thresholds = {thvalue}; - EmbeddingMgmt::Instance()->Init(rankInfo_, embInfos, thresholds, 0); + EmbeddingMgmt::Instance()->Init(rankInfo_, embInfos, 0); constexpr int testNum = 100; vector testKeys; @@ -95,7 +94,7 @@ TEST_F(EmbeddingMgmtTest, GetAttributes) ThresholdValue thvalue(tableName, 0, 0, 0, false); vector embInfos = {embInfo_}; vector thresholds = {thvalue}; - EmbeddingMgmt::Instance()->Init(rankInfo_, embInfos, thresholds, 0); + EmbeddingMgmt::Instance()->Init(rankInfo_, embInfos, 0); constexpr int testNum = 100; vector testKeys; diff --git a/src/tests/emb_table/embedding_static_test.cpp b/src/tests/emb_table/embedding_static_test.cpp index 09e72ca0a815746220ef4bdc5dd8056a087bbf4d..5d1f0ab769b4532861bf02bd96760f8dcce540b7 100644 --- a/src/tests/emb_table/embedding_static_test.cpp +++ b/src/tests/emb_table/embedding_static_test.cpp @@ -21,7 +21,6 @@ See the License for the specific language governing permissions and #include #include #include "utils/common.h" -#include "emb_table/emb_table.h" #include "emb_table/embedding_static.h" using namespace std; @@ -34,8 +33,8 @@ protected: EmbeddingStaticTest() { struct EmbInfoParams embParam(string("test1"), 0, 1000, 2000, true, true); - std::vector vocabsize = {100}; - std::vector initializeInfos = {}; + std::vector vocabsize = {100, 100, 100}; + vector initializeInfos = {}; std::vector ssdDataPath = {""}; vector maxStep = {1000}; embInfo_ = EmbInfo(embParam, vocabsize, initializeInfos, ssdDataPath); @@ -136,7 +135,8 @@ TEST_F(EmbeddingStaticTest, Key2OffsetEvict) } table->Key2Offset(testData, TRAIN_CHANNEL_ID); // 全部淘汰 - table->EvictKeys(testData); + vector testDataAdapt(testData.cbegin(), testData.cend()); + table->EvictKeys(testDataAdapt); vector new_data; for (size_t i = 0; i < testNum; ++i) { @@ -155,6 +155,7 @@ TEST_F(EmbeddingStaticTest, SaveKeyData) { vector embInfos = {embInfo_}; shared_ptr hbm = std::make_shared(embInfo_, rankInfo_, 0); + hbm->SetFileSystemPtr("test_dir"); hbm->Save("test_dir"); bool fileExist = false; if (access("./test_dir/test1/key", F_OK) == 0) { diff --git a/src/tests/file_system/hdfs_file_system_test.cpp b/src/tests/file_system/hdfs_file_system_test.cpp index a8c8bbf5a2e2307279b05b89f3312168b8a27e03..98f733f0a9efc998c37afac9fda2ad2c20b3af8a 100644 --- a/src/tests/file_system/hdfs_file_system_test.cpp +++ b/src/tests/file_system/hdfs_file_system_test.cpp @@ -17,7 +17,6 @@ See the License for the specific language governing permissions and #include #include "file_system/file_system_handler.h" -#include "file_system/hdfs_file_system/hdfs_file_system.h" #include "file_system/hdfs_file_system/hdfs_wrapper.h" using namespace std; @@ -27,10 +26,10 @@ using namespace emock; void MockHdfs() { + EMOCK(&HdfsWrapper::LoadHdfsLib).stubs().will(ignoreReturnValue()); hdfsFS ConnectFs; hdfsFile hdfsFileHandler; hdfsFileInfo* fileInfo; - EMOCK(&HdfsWrapper::LoadHdfsLib).stubs().will(ignoreReturnValue()); EMOCK(&HdfsWrapper::CloseHdfsLib).stubs().will(ignoreReturnValue()); EMOCK(&HdfsWrapper::Connect).stubs().will(returnValue(ConnectFs)); EMOCK(&HdfsWrapper::Disconnect).stubs().will(returnValue(1)); @@ -38,8 +37,6 @@ void MockHdfs() 
EMOCK(&HdfsWrapper::FreeFileInfo).stubs().will(ignoreReturnValue()); EMOCK(&HdfsWrapper::OpenFile).stubs().will(returnValue(hdfsFileHandler)); EMOCK(&HdfsWrapper::CloseFile).stubs().will(returnValue(1)); - EMOCK(&HdfsWrapper::Write).stubs().will(returnValue(1)); - EMOCK(&HdfsWrapper::Read).stubs().will(returnValue(1)); EMOCK(&HdfsWrapper::Seek).stubs().will(returnValue(1)); } @@ -78,31 +75,11 @@ TEST_F(HdfsFileSystemTest, CreateDirFailed) TEST_F(HdfsFileSystemTest, GetFileSize) { - hdfsFileInfo* fileInfo; - EMOCK(&HdfsWrapper::GetPathInfo).stubs().will(returnValue(fileInfo)); + std::unique_ptr fileInfo = std::make_unique(); + EMOCK(&HdfsWrapper::GetPathInfo).stubs().will(returnValue(fileInfo.get())); string filePath = "hdfs://master:9000/test_dir/"; auto fileSystemHandler = make_unique(); auto fileSystemPtr = fileSystemHandler->Create(filePath); EXPECT_NO_THROW(fileSystemPtr->GetFileSize(filePath)); } -TEST_F(HdfsFileSystemTest, testCase) -{ - string filePath = "hdfs://master:9000/test_dir/"; - auto fileSystemHandler = make_unique(); - auto fileSystemPtr = fileSystemHandler->Create(filePath); - - vector dirs; - dirs = fileSystemPtr->ListDir(filePath); - EXPECT_EQ(dirs.size(), 0); - - vector writeData = {0, 1, 2, 3, 4, 5}; - size_t testDataSize = writeData.size() * sizeof(int64_t); - EXPECT_NO_THROW(fileSystemPtr->Write(filePath, reinterpret_cast(writeData.data()), testDataSize)); - float p[5] = {1.1, 2.2, 3.3, 4.4, 5.5}; - vector writeData1 = {p, p+1, p+2, p+3, p+4}; - EXPECT_NO_THROW(fileSystemPtr->Write(filePath, writeData1, sizeof(float))); - - vector readData = {}; - EXPECT_NO_THROW(fileSystemPtr->Read(filePath, reinterpret_cast(readData.data()), 1)); -} \ No newline at end of file diff --git a/src/tests/file_system/local_file_system_test.cpp b/src/tests/file_system/local_file_system_test.cpp index dfe5d483f27010839333a27860b4d53c6c6939d1..2ea0d9d3291fb56e1053dce1eecba6480de206c9 100644 --- a/src/tests/file_system/local_file_system_test.cpp +++ b/src/tests/file_system/local_file_system_test.cpp @@ -16,7 +16,6 @@ See the License for the specific language governing permissions and #include #include "file_system/file_system_handler.h" -#include "file_system/local_file_system/local_file_system.h" using namespace std; using namespace MxRec; @@ -42,10 +41,10 @@ TEST(LocalFileSystem, WriteAndReadFile) TEST(LocalFileSystem, WriteEmbedding) { string filePath = "./write.data"; - float p[5] = {1.1, 2.2, 3.3, 4.4, 5.5}; - vector writeData = {p, p+1, p+2, p+3, p+4}; + vector writeData = {1.1, 2.2, 3.3, 4.4, 5.5}; + vector> writeData1 = {writeData}; auto fileSystemHandler = make_unique(); auto fileSystemPtr = fileSystemHandler->Create(filePath); - ssize_t res = fileSystemPtr->Write(filePath, writeData, sizeof(float)); + ssize_t res = fileSystemPtr->Write(filePath, writeData1, sizeof(float)); ASSERT_EQ(writeData.size() * sizeof(float), res); } diff --git a/src/tests/host_emb/host_emb_test.cpp b/src/tests/host_emb/host_emb_test.cpp deleted file mode 100644 index 05a636d913650d501d3af4ed15fa245767cfc429..0000000000000000000000000000000000000000 --- a/src/tests/host_emb/host_emb_test.cpp +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and - limitations under the License. -==============================================================================*/ - -#include -#include - -#include "host_emb/host_emb.h" -#include "tensorflow/core/framework/tensor.h" -#include "hd_transfer/hd_transfer.h" -#include "utils/singleton.h" - -using namespace std; -using namespace tensorflow; -using namespace MxRec; - -namespace { -bool operator==(const Tensor& tensor1, const Tensor& tensor2) -{ - if (tensor1.shape() != tensor2.shape()) { - return false; - } - auto tensor1_data = tensor1.flat(); - auto tensor2_data = tensor2.flat(); - for (int j = 0; j < tensor1_data.size(); j++) { - if (tensor1_data(j) != tensor2_data(j)) { - return false; - } - } - return true; -} - -bool operator==(const vector& p1, const vector& p2) -{ - if (p1.size() != p2.size()) { - return false; - } - for (int i = 0; i>> lookups; - vector host_emb; - host_emb.resize(15); - vector> p(5, vector(3)); - host_emb[0] = 1; - host_emb[1] = 3; - std::cout << host_emb[0] << std::endl; - for (int i = 0; i < 5; i++) { - p[i].assign(host_emb.begin() + i * 3, host_emb.begin() + (i + 1) * 3); - } - std::cout << p[0][0] << std::endl; - std::cout << '5' << std::endl; - vector q; - std::cout << '0' << std::endl; - for (int i = 0; i < 2; i++) { - Tensor tmpTensor(tensorflow::DT_INT32, { 3 }); - std::cout << '1' << std::endl; - auto tmpData = tmpTensor.flat(); - std::cout << '2' << std::endl; - for (int j = 0; j < 3; j++) { - tmpData(j) = p[i][j]; - std::cout << '3' << std::endl; - } - - q.emplace_back(tmpTensor); - std::cout << '4' << std::endl; - } - std::cout << '1' << std::endl; - std::cout << q[0].flat()(0) << std::endl; - std::cout << q[0].flat()(1) << std::endl; - std::cout << q[1].flat()(0) << std::endl; - ASSERT_EQ(1, 1); -} - -TEST(HostEmb, DefaultConstructor) -{ - HostEmb h; - h.procThreadsForTrain.emplace_back(make_unique([] {})); - h.Join(TRAIN_CHANNEL_ID); - ASSERT_EQ(h.procThreadsForTrain.size(), 0); - - h.procThreadsForEval.emplace_back(make_unique([] {})); - h.Join(EVAL_CHANNEL_ID); - ASSERT_EQ(h.procThreadsForEval.size(), 0); -} - -} \ No newline at end of file diff --git a/src/tests/key_process/feature_admit_and_evict_test.cpp b/src/tests/key_process/feature_admit_and_evict_test.cpp index 09cadc7feada287b93dbb42b4419acee435e4373..dffce96cc3fa0244de37b69052420d47fe143054 100644 --- a/src/tests/key_process/feature_admit_and_evict_test.cpp +++ b/src/tests/key_process/feature_admit_and_evict_test.cpp @@ -248,7 +248,7 @@ protected: currTime = time(nullptr); if (currTime - lastTime >= SleepTime::SLEEP_SECOND_4) { LOG_INFO("Evict-thread doing at currTime[{}] ...", currTime); - map> evictPosMap {}; + map> evictPosMap {}; faae.FeatureEvict(evictPosMap); lastTime = currTime; } @@ -258,7 +258,7 @@ protected: } void WaitEvictThread() { - map> evictPosMap {}; + map> evictPosMap {}; faae.FeatureEvict(evictPosMap); // 退出前保证执行了一次“淘汰” isExitFlag = true; if (evictThr.joinable()) { diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index a5e618cd08e809420d452a78b8a5dbb9ce2f0e24..fb2be40bf8320792b907dae58a872dd6d0a1f8aa 100644 --- a/src/tests/key_process/key_process_test.cpp +++ 
b/src/tests/key_process/key_process_test.cpp @@ -23,13 +23,13 @@ See the License for the specific language governing permissions and #include "ock_ctr_common/include/unique.h" #include "ock_ctr_common/include/error_code.h" #include "emb_table/embedding_mgmt.h" +#include "emock/emock.hpp" using namespace std; using namespace MxRec; using namespace testing; static constexpr size_t BATCH_NUM_EACH_THREAD = 3; -ock::ctr::FactoryPtr factory; class SimpleThreadPool { public: @@ -45,21 +45,13 @@ public: } }; -static void CTRLog(int level, const char *msg) -{ - switch (level) { - case 0: - LOG_DEBUG(msg); - break; - default: - break; - } -} - class KeyProcessTest : public testing::Test { protected: void SetUp() { + int defaultUBSize = 196608; + EMOCK(GetUBSize).stubs().with(any()).will(returnValue(defaultUBSize)); + int claimed; MPI_Query_thread(&claimed); ASSERT_EQ(claimed, MPI_THREAD_MULTIPLE); @@ -76,7 +68,6 @@ protected: rankInfo.isDDR = false; rankInfo.useDynamicExpansion = false; rankInfo.ctrlSteps = { 1, -1 }; - rankInfo.useHot = false; // 初始化emb信息 GenEmbInfos(embNum, embInfos, fieldNums); splits = fieldNums; @@ -318,6 +309,7 @@ protected: void TearDown() { // delete + GlobalMockObject::reset(); } }; @@ -348,7 +340,7 @@ TEST_F(KeyProcessTest, Start) ASSERT_EQ(process.Start(), 0); setenv("keyProcessThreadNum", "abc", 1); ASSERT_EQ(process.Start(), 0); - CTRLog(0, "key process start successful"); + LOG_INFO("key process start successful"); process.Destroy(); } @@ -424,34 +416,6 @@ TEST_F(KeyProcessTest, PaddingHashSplitWithFAAE) } } -TEST_F(KeyProcessTest, HotHashSplit) -{ - PrepareBatch(); - ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); - LOG_INFO("CPU Core Num: %{}", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 - - auto fn = [this](int channel, int id) { - auto embName = embInfos[0].name; - process.hotEmbTotCount[embName] = 10; - vector splitKeys; - vector restore; - vector hotPos; - unique_ptr batch; - batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue - LOG_INFO("rankid :{},batchid: {}", rankInfo.rankId, batch->batchId); - tie(splitKeys, restore, hotPos) = process.HotHashSplit(batch); - LOG_INFO("rankid :{},batchid: {}, hotPos {}", rankInfo.rankId, batch->batchId, VectorToString(hotPos)); - }; // for clean code - for (int channel = 0; channel < 1; ++channel) { - for (int id = 0; id < 1; ++id) { - // use lambda expression initialize thread - process.procThreads.emplace_back(std::make_unique(fn, channel, id)); - } - } - this_thread::sleep_for(10s); - process.Destroy(); -} - TEST_F(KeyProcessTest, GetScAll) { vector keyScLocal(worldSize, worldRank + 1); // 用worldRank+1初始化发送数据量 @@ -523,38 +487,6 @@ TEST_F(KeyProcessTest, BuildRestoreVec_4cpu) ASSERT_THAT(restore, ElementsAreArray(allExpectRestore[worldRank])); } -// hot模式,batch随机数,ProcessSplitKeys后人为校验lookupKeys、scAll、restore -TEST_F(KeyProcessTest, BuildRestoreVec_rebuilt) -{ - PrepareBatch(); - ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); - LOG_INFO("CPU Core Num: {}", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 - - auto fn = [this](int channel, int id) { - auto embName = embInfos[0].name; - vector splitKeys; - vector restore; - vector hotPos; - unique_ptr batch; - batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue - LOG_INFO("rankid :{}, batchid: {}", rankInfo.rankId, batch->batchId); - tie(splitKeys, restore, hotPos) = process.HotHashSplit(batch); - auto [lookupKeys, scAll, ss] = process.ProcessSplitKeys(batch, id, splitKeys); - process.BuildRestoreVec(batch, 
ss, restore, hotPos.size()); - LOG_INFO("rankid :{}, batchid: {}, lookupKeys: {}, scAll: {}, restore after build {}", - rankInfo.rankId, batch->batchId, VectorToString(lookupKeys), - VectorToString(scAll), VectorToString(restore)); - }; // for clean code - for (int channel = 0; channel < 1; ++channel) { - for (int id = 0; id < KEY_PROCESS_THREAD; ++id) { - // use lambda expression initialize thread - process.procThreads.emplace_back(std::make_unique(fn, channel, id)); - } - } - this_thread::sleep_for(10s); - process.Destroy(); -} - // 准入模式,batch随机数,ProcessSplitKeys后人为校验lookupKeys、scAll、count TEST_F(KeyProcessTest, GetCountRecv) { @@ -634,116 +566,6 @@ TEST_F(KeyProcessTest, GetUniqueConfig) process.GetUniqueConfig(uniqueConf); } -// HBM端到端测试,动态shape,固定batch输入 -TEST_F(KeyProcessTest, KeyProcessTaskHelper) -{ - rankInfo.isDDR = false; - rankInfo.useStatic = false; - rankInfo.useHot = false; - rankInfo.useDynamicExpansion = false; - EmbeddingMgmt::Instance()->Init(rankInfo, embInfos); - ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); - ASSERT_EQ(process.isRunning, true); - int batchId = 0; - int channelId = 0; - auto batch = GenBatch(embInfos[0].name, batchId, channelId); // 测试一个表 - - LOG_INFO("KeyProcessTaskHelper, rankid: {}, batchid: {}, batchSize: {}", - rankInfo.rankId, batch->batchId, batch->sample.size()); - - ASSERT_EQ(process.KeyProcessTaskHelper(batch, channelId, 0), true); // threadId = 0 - auto infoVecs = process.GetInfoVec(batchId, embInfos[0].name, channelId, ProcessedInfo::RESTORE); - ASSERT_NE(infoVecs, nullptr); - auto all2all = process.GetInfoVec(batchId, embInfos[0].name, channelId, ProcessedInfo::ALL2ALL); - ASSERT_NE(all2all, nullptr); - - ASSERT_EQ(CheckMatrixTensor(*all2all, allExpectAll2all), true); - ASSERT_EQ(CheckFlatTensor({infoVecs->back()}, allExpectOffset[worldRank]), true); - infoVecs->pop_back(); - ASSERT_EQ(CheckFlatTensor(*infoVecs, allExpectRestore[worldRank]), true); - LOG_INFO("KeyProcessTaskHelper, rankid: {}, batchid: {}, normal status success", rankInfo.rankId, batch->batchId); - // 测试batchId错误 - HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); - hybridMgmtBlock->hybridBatchId[0] = 1; - ASSERT_EQ(process.GetInfoVec(batchId, embInfos[0].name, channelId, ProcessedInfo::RESTORE), nullptr); - LOG_INFO("KeyProcessTaskHelper, rankid: {}, batchid: {}, batchId exception success", - rankInfo.rankId, batch->batchId); - // 测试empty场景 - hybridMgmtBlock->pythonBatchId[1] = 1; - hybridMgmtBlock->hybridBatchId[1] = 1; - hybridMgmtBlock->readEmbedBatchId[1] = 1; - hybridMgmtBlock->loop[1] = 1; - ASSERT_EQ(process.GetInfoVec(batchId + 1, embInfos[0].name, channelId + 1, ProcessedInfo::RESTORE), nullptr); - LOG_INFO("KeyProcessTaskHelper, rankid: {}, batchid: {}, batch empty success", rankInfo.rankId, batch->batchId); - // eos - process.SetEos(1, 1); - ASSERT_EQ(process.GetInfoVec(batchId + 1, embInfos[0].name, channelId + 1, ProcessedInfo::RESTORE), nullptr); - LOG_INFO("KeyProcessTaskHelper, rankid: {}, batchid: {}, eos status success", rankInfo.rankId, batch->batchId); - this_thread::sleep_for(10s); - process.Destroy(); -} - -// DDR端到端测试,静态shape,固定batch输入 -TEST_F(KeyProcessTest, KeyProcessTaskHelperDDR) -{ - rankInfo.isDDR = true; - rankInfo.useStatic = true; - rankInfo.useHot = false; - rankInfo.useDynamicExpansion = false; - EmbeddingMgmt::Instance()->Init(rankInfo, embInfos); - ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); - ASSERT_EQ(process.isRunning, true); - int batchId = 0; - int channelId = 0; - auto batch = 
GenBatch(embInfos[0].name, batchId, channelId); // 测试第一个表 - HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); - hybridMgmtBlock->hybridBatchId[0] = 0; - LOG_INFO("KeyProcessTaskHelperDDR, rankid: {}, batchid: {}", rankInfo.rankId, batch->batchId); - - ASSERT_EQ(process.KeyProcessTaskHelper(batch, channelId, 0), true); // threadId = 0 - - auto lookupKeys = process.GetLookupKeys(batchId, embInfos[0].name, channelId); // lookup list返回的不是tensor - ASSERT_EQ(lookupKeys.size(), sendCount * worldSize); - LOG_INFO("KeyProcessTaskHelperDDR, rankid: {}, batchid: {}, lookupKeys: {}", - rankInfo.rankId, batch->batchId, VectorToString(lookupKeys)); - ASSERT_EQ(CheckPaddingVec(lookupKeys, allExpectLookupKeys[worldRank]), true); - - auto infoVecs = process.GetInfoVec(batchId, embInfos[0].name, channelId, ProcessedInfo::RESTORE); - ASSERT_NE(infoVecs, nullptr); - int col = allExpectRestore[worldRank].size(); - auto tmpTensor = (*infoVecs).at(0); - auto tmpData = tmpTensor.flat(); - - vector actualGetRestore(col); - for (int j = 0; j < col; j++) { - actualGetRestore[j] = tmpData(j); - } - LOG_INFO("KeyProcessTaskHelperDDR, rankid: {}, batchid: {}, Restore: {}", - rankInfo.rankId, batch->batchId, VectorToString(actualGetRestore)); - ASSERT_THAT(actualGetRestore, ElementsAreArray(allExpectRestoreStatic[worldRank])); - LOG_INFO("KeyProcessTaskHelperDDR, rankid: {}, batchid: {}, normal status success", - rankInfo.rankId, batch->batchId); - - // 测试batchId错误 - hybridMgmtBlock->hybridBatchId[0] = 1; - ASSERT_EQ(process.GetLookupKeys(batchId, embInfos[0].name, channelId).empty(), true); - LOG_INFO("KeyProcessTaskHelper, rankid: {}, batchid: {}, batchId exception success", - rankInfo.rankId, batch->batchId); - // 测试empty场景 - hybridMgmtBlock->pythonBatchId[1] = 1; - hybridMgmtBlock->hybridBatchId[1] = 1; - hybridMgmtBlock->readEmbedBatchId[1] = 1; - hybridMgmtBlock->loop[1] = 1; - ASSERT_EQ(process.GetLookupKeys(batchId + 1, embInfos[0].name, channelId + 1).empty(), true); - LOG_INFO("KeyProcessTaskHelper, rankid: {}, batchid: {}, batch empty success", rankInfo.rankId, batch->batchId); - // eos - process.SetEos(1, 1); - ASSERT_EQ(process.GetLookupKeys(batchId + 1, embInfos[0].name, channelId + 1).empty(), true); - LOG_INFO("KeyProcessTaskHelper, rankid: {}, batchid: {}, eos status success", rankInfo.rankId, batch->batchId); - this_thread::sleep_for(10s); - process.Destroy(); -} - TEST_F(KeyProcessTest, InitializeUnique) { ASSERT_EQ(ock::ctr::Factory::Create(factory), -1); diff --git a/src/tests/leaks.supp b/src/tests/leaks.supp new file mode 100644 index 0000000000000000000000000000000000000000..ebe0718d956c0f62b60ae05e6782dfe59133f8b3 --- /dev/null +++ b/src/tests/leaks.supp @@ -0,0 +1,21 @@ +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# There are known leaks. +# 1.known mpi leaks. 
+leak:libmpi.so* +leak:libopen-pal.so* +leak:libpmix.so* +leak:libc.so* \ No newline at end of file diff --git a/src/tests/ssd_cache/cache_manager_test.cpp b/src/tests/ssd_cache/cache_manager_test.cpp index 677939d2c68a376cd29f162cabfa505cfee7c0e9..164e667a28b05c3c064d08ad1e1b42cd63b94aba 100644 --- a/src/tests/ssd_cache/cache_manager_test.cpp +++ b/src/tests/ssd_cache/cache_manager_test.cpp @@ -18,11 +18,9 @@ See the License for the specific language governing permissions and #include #include "absl/container/flat_hash_map.h" -#include "host_emb/host_emb.h" -#include "ssd_cache/lfu_cache.h" -#include "ssd_cache/cache_manager.h" +#include "l3_storage/lfu_cache.h" +#include "l3_storage/cache_manager.h" #include "utils/common.h" -#include "emb_table/embedding_ddr.h" using namespace std; using namespace MxRec; @@ -36,16 +34,21 @@ void InitSSDEngine(CacheManager& manager, string embTableName, uint64_t ssdSize) { // Init ssd engine data chrono::seconds period = chrono::seconds(120); - manager.ssdEngine->SetCompactPeriod(period); - manager.ssdEngine->SetCompactThreshold(1); - manager.ssdEngine->CreateTable(embTableName, {SSD_SAVE_PATH}, ssdSize); - vector ssdKeys = {15, 25}; // 预设15, 25存储在SSD - std::vector> ssdEmbData = {{15.0f}, - {25.0f}}; - auto& excludeMap = manager.excludeDDRKeyCountMap[embTableName]; + auto ssdEngine = static_pointer_cast(manager.l3Storage); + ssdEngine->SetCompactPeriod(period); + ssdEngine->SetCompactThreshold(1); + ssdEngine->CreateTable(embTableName, {SSD_SAVE_PATH}, ssdSize); + vector ssdKeys = {15, 25}; // 预设15, 25存储在SSD + auto emb1 = new float(15.0f); + auto emb2 = new float(25.0f); + uint64_t extEmbeddingSize = 1; + std::vector ssdEmbData = {{emb1}, {emb2}}; + auto& excludeMap = manager.preProcessMapper[embTableName].excludeDDRKeyCountMap; excludeMap[15] = 3; // 初始化次数 excludeMap[25] = 5; - manager.ssdEngine->InsertEmbeddings(embTableName, ssdKeys, ssdEmbData); + ssdEngine->InsertEmbeddingsByAddr(embTableName, ssdKeys, ssdEmbData, extEmbeddingSize); + delete emb1; + delete emb2; } void InitDDREmbData(absl::flat_hash_map& loadData, string& embTableName, @@ -94,7 +97,7 @@ protected: LFUCache cache2; cacheManager.ddrKeyFreqMap[embTableName2] = cache2; PutKeyInfo(cacheManager.ddrKeyFreqMap[embTableName2], input_keys); - unordered_map excludeDDRKeyFreq; + unordered_map excludeDDRKeyFreq; excludeDDRKeyFreq[27] = 10; excludeDDRKeyFreq[30] = 10; cacheManager.excludeDDRKeyCountMap[embTableName] = excludeDDRKeyFreq; @@ -105,14 +108,14 @@ protected: InitDDREmbData(loadData, embTableName, mgmtEmbInfos); InitDDREmbData(loadData, embTableName2, mgmtEmbInfos); - cacheManager.Init(hEmb, mgmtEmbInfos); + ock::ctr::EmbCacheManagerPtr embCachePtr = nullptr; + + auto ssdEngine = make_shared(); + cacheManager.Init(embCachePtr, mgmtEmbInfos, ssdEngine); InitSSDEngine(cacheManager, embTableName, 5); InitSSDEngine(cacheManager, embTableName2, 10); // load ddr emb data - cacheManager.hostEmbs->hostEmbs = loadData; - - auto& embMap = cacheManager.hostEmbs->hostEmbs; } CacheManager cacheManager; @@ -126,49 +129,12 @@ protected: vector input_keys = {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 6, 6, 8, 9}; string embTableName = "table1"; string embTableName2 = "table2"; - HostEmb* hEmb = Singleton::GetInstance(); void TearDown() { } }; -TEST_F(CacheManagerTest, RefreshFreqInfo) -{ - vector ddr2HbmKeys = {8, 9}; - cacheManager.RefreshFreqInfoCommon(embTableName, ddr2HbmKeys, TransferType::DDR_2_HBM); - ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].minFreq, 2); - 
ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].keyTable.size(), 5); - ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].freqTable.size(), 2); - ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].Get(8), -1); - ASSERT_EQ(cacheManager.excludeDDRKeyCountMap[embTableName].size(), 6); - - // HBM转移到DDR 频次数据设置构造 - cacheManager.excludeDDRKeyCountMap[embTableName][150] = 4; - cacheManager.excludeDDRKeyCountMap[embTableName][151] = 1; - vector hbm2DdrKeys = {150, 151}; - ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].Get(151), -1); - cacheManager.RefreshFreqInfoCommon(embTableName, hbm2DdrKeys, TransferType::HBM_2_DDR); - ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].Get(150), 4); - ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].Get(151), 1); - ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].minFreq, 1); - ASSERT_EQ(cacheManager.excludeDDRKeyCountMap[embTableName].size(), 6); - - vector ddr2EvictKeys = {151}; - cacheManager.RefreshFreqInfoCommon(embTableName, ddr2EvictKeys, TransferType::DDR_2_EVICT); - ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].Get(151), -1); - ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].freqTable.size(), 3); - ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].minFreq, 2); - - // HBM2Evict - cacheManager.excludeDDRKeyCountMap[embTableName][160] = 1; - vector hbm2EvictKeys = {160}; - cacheManager.RefreshFreqInfoCommon(embTableName, hbm2EvictKeys, TransferType::HBM_2_EVICT); - const auto it = cacheManager.excludeDDRKeyCountMap[embTableName].find(160); - ASSERT_EQ(it, cacheManager.excludeDDRKeyCountMap[embTableName].end()); - LOG_INFO("test RefreshFreqInfo end."); -} - TEST_F(CacheManagerTest, PutKey) { vector putDDRKeys = {1, 9, 8, 15}; @@ -181,236 +147,33 @@ TEST_F(CacheManagerTest, PutKey) LOG_INFO("test PutKey end."); } -TEST_F(CacheManagerTest, IsKeyInSSD) +TEST_F(CacheManagerTest, IsKeyInL3Storage) { vector checkKeys = {1, 2, 15, 25}; - ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, checkKeys[0])); - ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, checkKeys[1])); - ASSERT_TRUE(cacheManager.IsKeyInSSD(embTableName, checkKeys[2])); - ASSERT_TRUE(cacheManager.IsKeyInSSD(embTableName, checkKeys[3])); - LOG_INFO("test IsKeyInSSD end."); -} - -TEST_F(CacheManagerTest, TransferDDREmbWithSSDByEmptyExternalKey) -{ - EmbeddingDDR table; - - vector currentKeys = {55, 65, 75}; - table.keyOffsetMap[55] = 119; - table.keyOffsetMap[65] = 118; - table.keyOffsetMap[75] = 116; - - TableInfo ti = table.GetTableInfo(); - - auto ret = cacheManager.TransferDDREmbWithSSD(ti, currentKeys, TRAIN_CHANNEL_ID); - ASSERT_EQ(ret, TransferRet::TRANSFER_OK); - LOG_INFO("test TransferDDREmbWithSSDByEmptyExternalKey end."); -} - -TEST_F(CacheManagerTest, TransferDDREmbWithSSDByAllProcess) -{ - vector ssdKeys = {15, 25}; - vector> ssdKeyEmbInfo = {{1.5f}, {2.5f}}; - - // init EmbeddingDDR - EmbeddingDDR table; - table.name = embTableName; - table.devVocabSize = 20; - table.hostVocabSize = 100; - table.maxOffset = 118; - table.evictHostPos.emplace_back(110); // 淘汰列表 - - TableInfo ti = table.GetTableInfo(); - - // 构造已经存储早DDR中key和offset对应关系; DDR的offset在映射表中范围是 20~119 - table.keyOffsetMap[9] = 117; // DDR中相对位置: 97 - table.keyOffsetMap[8] = 116; // DDR中相对位置: 96 - table.keyOffsetMap[6] = 114; // DDR中相对位置: 94 - table.keyOffsetMap[4] = 112; // DDR中相对位置: 92 - table.keyOffsetMap[3] = 111; // DDR中相对位置: 91 - table.keyOffsetMap[2] = 21; // DDR中相对位置: 1 - table.keyOffsetMap[1] = 20; // DDR中相对位置: 0 - - // 检查构造数据正确性 - auto& embMap = cacheManager.hostEmbs->hostEmbs; - const auto& it = 
embMap.find(embTableName); - auto& hostData = it->second.embData; - ASSERT_TRUE(fabs(hostData[0][0] - 1.0f) < EPSILON); - ASSERT_TRUE(fabs(hostData[1][0] - 2.0f) < EPSILON); - ASSERT_TRUE(fabs(hostData[94][0] - 6.0f) < EPSILON); - ASSERT_TRUE(fabs(hostData[97][0] - 9.0f) < EPSILON); - auto& excludeKeyCountMap = cacheManager.excludeDDRKeyCountMap[embTableName]; - ASSERT_EQ(excludeKeyCountMap[15], 3); - ASSERT_EQ(excludeKeyCountMap[25], 5); - ASSERT_FALSE(cacheManager.ssdEngine->IsKeyExist(embTableName, 9)); - ASSERT_FALSE(cacheManager.ssdEngine->IsKeyExist(embTableName, 8)); - ASSERT_TRUE(cacheManager.IsKeyInSSD(embTableName, 15)); - - // externalKeys: SSD(15, 25) + newKey(55, 65, 75) - // 训练场景,构造结果:offsetAvailableSize=20+100-118+evictPos.size()=3 - // cacheManager中的频次数据(低-高): 9 8 6 4 3 2 1 - // 构造空间超出SSD可用上限 - vector exceedKeys = {15, 25, 6, 4, 55, 65, 75, 85, 95, 105, 115}; - auto spaceError1 = cacheManager.TransferDDREmbWithSSD(ti, exceedKeys, TRAIN_CHANNEL_ID); - ASSERT_EQ(spaceError1, TransferRet::SSD_SPACE_NOT_ENOUGH); - - // 构造训练+超SSD可用+当前批次中不包含报错在SSD的key - vector keys2 = {6, 4, 55, 65, 75, 85, 95, 105, 115, 125, 135}; - auto spaceError2 = cacheManager.TransferDDREmbWithSSD(ti, exceedKeys, TRAIN_CHANNEL_ID); - ASSERT_EQ(spaceError2, TransferRet::SSD_SPACE_NOT_ENOUGH); - - // 构造当前批次key 存储位置: SSD(15, 25) DDR(6, 4) newKey(55, 65, 75) - vector currentKeys = {15, 25, 6, 4, 55, 65, 75}; - // 需要从ddr转移4个key到ssd, 低频数据中6 4在当前批次key中,不会被转移,构造的数据转移key:9, 8, 3, 2 - auto ret = cacheManager.TransferDDREmbWithSSD(ti, currentKeys, TRAIN_CHANNEL_ID); - - // 检查处理后数据正确性 - ASSERT_EQ(ret, TransferRet::TRANSFER_OK); - ASSERT_TRUE(fabs(hostData[94][0] - 6.0f) < EPSILON); // DDR内未移动的数据 - ASSERT_TRUE(fabs(hostData[96][0] - 25.0f) < EPSILON); // SSD转移到DDR的数据 - ASSERT_TRUE(fabs(hostData[97][0] - 15.0f) < EPSILON); // SSD转移到DDR的数据 - ASSERT_EQ(table.evictHostPos.size(), 1); - ASSERT_EQ(table.evictHostPos.back(), 110); - - // 原DDR中最小频次key(9,8)次数(1)被转移到SSD,SSD转移到DDR的key(15,25)次数(3,5), DDR内频次索引应变为2 - ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].minFreq, 2); - ASSERT_TRUE(cacheManager.IsKeyInSSD(embTableName, 9)); - ASSERT_TRUE(cacheManager.IsKeyInSSD(embTableName, 8)); - ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, 15)); - LOG_INFO("test TransferDDREmbWithSSDByAllProcess end."); -} - -TEST_F(CacheManagerTest, TransferDDREmbWithSSDByEmptyExternalSSDKey) -{ - // 训练+评估:构造DDR剩余空间足够,externalSSDKeys为空 - EmbeddingDDR table; - table.name = embTableName; - table.devVocabSize = 20; - table.hostVocabSize = 100; - table.keyOffsetMap[6] = 114; // DDR中相对位置: 94 - table.keyOffsetMap[4] = 112; // DDR中相对位置: 92 - // 剩余3个可用空间(DDR剩余2个, 相对位置:98 99; DDR淘汰列表1个) - table.maxOffset = 118; - table.evictHostPos.emplace_back(110); - - TableInfo ti = table.GetTableInfo(); - - vector currentKeys = {6, 4, 55, 65, 75}; - auto ret = cacheManager.TransferDDREmbWithSSD(ti, currentKeys, TRAIN_CHANNEL_ID); - ASSERT_EQ(ret, TransferRet::TRANSFER_OK); - auto retByEval = cacheManager.TransferDDREmbWithSSD(ti, currentKeys, EVAL_CHANNEL_ID); - ASSERT_EQ(retByEval, TransferRet::TRANSFER_OK); - - // 评估场景, DDR剩余空间不足, externalSSDKeys为空 - vector currentKeys2 = {6, 4, 55, 65, 75, 85, 95, 105, 115}; - auto ret2 = cacheManager.TransferDDREmbWithSSD(ti, currentKeys2, EVAL_CHANNEL_ID); - ASSERT_EQ(ret2, TransferRet::TRANSFER_OK); - // 训练场景,返回ssd空间不足 - auto ret3 = cacheManager.TransferDDREmbWithSSD(ti, currentKeys2, TRAIN_CHANNEL_ID); - ASSERT_EQ(ret3, TransferRet::SSD_SPACE_NOT_ENOUGH); - LOG_INFO("test TransferDDREmbWithSSDByEmptyExternalSSDKey end."); -} - 
-TEST_F(CacheManagerTest, TransferDDREmbWithSSDByEval) -{ - // 评估+DDR剩余空间足够+externalSSDKeys为空 - EmbeddingDDR table; - table.name = embTableName; - table.devVocabSize = 20; - table.hostVocabSize = 100; - table.keyOffsetMap[9] = 117; // DDR中相对位置: 97 - table.keyOffsetMap[8] = 116; // DDR中相对位置: 96 - table.keyOffsetMap[6] = 114; // DDR中相对位置: 94 - table.keyOffsetMap[4] = 112; // DDR中相对位置: 92 - // 剩余3个可用空间(DDR剩余2个, 相对位置:98 99; DDR淘汰列表1个) - table.maxOffset = 118; - table.evictHostPos.emplace_back(110); // 淘汰列表 - - TableInfo ti = table.GetTableInfo(); - - vector currentKeys = {6, 4, 55, 65, 75}; - auto ret = cacheManager.TransferDDREmbWithSSD(ti, currentKeys, EVAL_CHANNEL_ID); - ASSERT_EQ(ret, TransferRet::TRANSFER_OK); - LOG_INFO("test eval+space enough+externalSSDKeysEmpty ok."); - - // 评估+DDR剩余空间足够+externalSSDKeys非空 - vector currentKeys2 = {15, 25, 6, 4, 55, 65, 75, 85, 95, 105, 115}; - auto ret2 = cacheManager.TransferDDREmbWithSSD(ti, currentKeys2, EVAL_CHANNEL_ID); - ASSERT_EQ(ret2, TransferRet::TRANSFER_OK); - // 检查处理后数据正确性 - const auto& it = cacheManager.hostEmbs->hostEmbs.find(embTableName); - auto& hostData = it->second.embData; - ASSERT_TRUE(fabs(hostData[94][0] - 6.0f) < EPSILON); // DDR内未移动的数据 - ASSERT_TRUE(fabs(hostData[98][0] - 25.0f) < EPSILON); // SSD转移到DDR的数据 - ASSERT_TRUE(fabs(hostData[90][0] - 15.0f) < EPSILON); // SSD转移到DDR的数据 - ASSERT_EQ(table.evictHostPos.size(), 0); - // 原DDR中最小频次key(9,8)次数(1)被转移到SSD,SSD转移到DDR的key(15,25)次数(3,5), DDR内频次索引应变为2 - ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].minFreq, 1); - ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, 9)); - ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, 8)); - ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, 15)); - LOG_INFO("test eval+space enough+externalSSDKeysNotEmpty ok."); -} - -TEST_F(CacheManagerTest, TransferDDREmbWithSSDByDDRSpaceNotEnough) -{ - // 构造DDR所有空间不满足存放当前批次数据 - EmbeddingDDR table; - table.name = embTableName2; - table.devVocabSize = 20; - table.hostVocabSize = 10; - table.maxOffset = 30; - table.keyOffsetMap[6] = 9; - table.keyOffsetMap[4] = 8; - - TableInfo ti = table.GetTableInfo(); - - // keys size:10, ddr keys:2 externalKeys:8 externalSSDKeys:0 - vector currentKeys = {6, 4, 101, 102, 103, 104, 105, 106, 107, 108}; - auto ret = cacheManager.TransferDDREmbWithSSD(ti, currentKeys, TRAIN_CHANNEL_ID); - ASSERT_EQ(ret, TransferRet::DDR_SPACE_NOT_ENOUGH); - LOG_INFO("test train+ddr space enough+externalSSDKeysEmpty ok."); + ASSERT_FALSE(cacheManager.IsKeyInL3Storage(embTableName, checkKeys[0])); + ASSERT_FALSE(cacheManager.IsKeyInL3Storage(embTableName, checkKeys[1])); + ASSERT_TRUE(cacheManager.IsKeyInL3Storage(embTableName, checkKeys[2])); + ASSERT_TRUE(cacheManager.IsKeyInL3Storage(embTableName, checkKeys[3])); + LOG_INFO("test IsKeyInL3Storage end."); } -TEST_F(CacheManagerTest, EvictSSDEmbedding) +TEST_F(CacheManagerTest, EvictL3StorageEmbedding) { // 构造时ssd中已存在的key: 15 25 - emb_key_t key = 15; - vector ssdKeys = {key}; - cacheManager.EvictSSDEmbedding(embTableName, ssdKeys); - ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, key)); + emb_cache_key_t key = 15; + vector ssdKeys = {key}; + cacheManager.EvictL3StorageEmbedding(embTableName, ssdKeys); + int maxLoop = 1000; + while (!cacheManager.l3StorageEvictThreads.empty() && maxLoop > 0) { + this_thread::sleep_for(1ms); + maxLoop--; + } + ASSERT_FALSE(cacheManager.IsKeyInL3Storage(embTableName, key)); const auto it = cacheManager.excludeDDRKeyCountMap[embTableName].find(key); ASSERT_EQ(it, 
cacheManager.excludeDDRKeyCountMap[embTableName].end()); - LOG_INFO("test EvictSSDEmbedding end."); + LOG_INFO("test EvictL3StorageEmbedding end."); } TEST_F(CacheManagerTest, LoadTest) { - cacheManager.ddrKeyFreqMap.clear(); - cacheManager.excludeDDRKeyCountMap.clear(); - unordered_map> ddrMap; - string embTableName = "table1"; - unordered_map ddrTableMap; - ddrTableMap.emplace(1, 3); - ddrTableMap.emplace(2, 3); - ddrTableMap.emplace(3, 3); - ddrTableMap.emplace(4, 2); - ddrTableMap.emplace(6, 2); - ddrTableMap.emplace(8, 1); - ddrTableMap.emplace(9, 1); - ddrMap.emplace(embTableName, ddrTableMap); - unordered_map> excludeDdrMap; - unordered_map excludeDdrTableMap; - excludeDdrTableMap.emplace(15, 1); - excludeDdrTableMap.emplace(25, 5); - excludeDdrMap.emplace(embTableName, excludeDdrTableMap); - cacheManager.Load(ddrMap, excludeDdrMap, 0, 1, 0); - // 数据检查 - auto& ddrKeyFreqMap = cacheManager.ddrKeyFreqMap; - auto& excludeDDRKeyCountMap = cacheManager.excludeDDRKeyCountMap; - ASSERT_EQ(ddrKeyFreqMap[embTableName].minFreq, 1); - ASSERT_EQ(ddrKeyFreqMap[embTableName].freqTable.size(), 3); - ASSERT_EQ(ddrKeyFreqMap[embTableName].Get(2), 3); - ASSERT_EQ(ddrKeyFreqMap[embTableName].Get(12), -1); - ASSERT_EQ(excludeDDRKeyCountMap[embTableName][25], 5); } \ No newline at end of file diff --git a/src/tests/ssd_cache/lfu_cache_test.cpp b/src/tests/ssd_cache/lfu_cache_test.cpp index 1adf4aada771f4951275d7012b98d8856de0318a..500e398988ea701dbd045387534a016fc555bf60 100644 --- a/src/tests/ssd_cache/lfu_cache_test.cpp +++ b/src/tests/ssd_cache/lfu_cache_test.cpp @@ -16,7 +16,7 @@ See the License for the specific language governing permissions and #include #include -#include "ssd_cache/lfu_cache.h" +#include "l3_storage/lfu_cache.h" using namespace std; using namespace MxRec; @@ -31,7 +31,7 @@ using namespace testing; */ vector INPUT_KEYS = {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 6, 6, 8, 9}; -inline void CompareHandleRet(vector& leastFreqKeys, vector& leastFreq, +inline void CompareHandleRet(vector& leastFreqKeys, vector& leastFreq, vector& expectKeys, vector& expectFreq) { @@ -81,8 +81,8 @@ TEST(LFUCache, PutInitTest) cache.PutWithInit(6, 2); cache.PutWithInit(8, 1); cache.PutWithInit(9, 1); - vector retainedKeys = {4, 6}; - vector leastFreqKeys; + vector retainedKeys = {4, 6}; + vector leastFreqKeys; vector leastFreq; cache.GetAndDeleteLeastFreqKeyInfo(2, retainedKeys, leastFreqKeys, leastFreq); vector expectKeys = {9, 8}; @@ -95,8 +95,8 @@ TEST(LFUCache, LFUDeleteTotalFreqListTest) { LFUCache cache; PutKeys(cache, INPUT_KEYS); - vector retainedKeys = {4, 6, 8, 9}; - vector leastFreqKeys; + vector retainedKeys = {4, 6, 8, 9}; + vector leastFreqKeys; vector leastFreq; cache.GetAndDeleteLeastFreqKeyInfo(2, retainedKeys, leastFreqKeys, leastFreq); vector expectKeys = {3, 2}; @@ -108,8 +108,8 @@ TEST(LFUCache, BaseCacheTest) { LFUCache cache; PutKeys(cache, INPUT_KEYS); - vector retainedKeys = {8, 4, 6, 2}; - vector leastFreqKeys; + vector retainedKeys = {8, 4, 6, 2}; + vector leastFreqKeys; vector leastFreq; cache.GetAndDeleteLeastFreqKeyInfo(2, retainedKeys, leastFreqKeys, leastFreq); vector expectKeys = {9, 3}; @@ -120,5 +120,5 @@ TEST(LFUCache, BaseCacheTest) cache.Put(9); ASSERT_EQ(cache.Get(9), 1); cache.Put(9); - ASSERT_EQ(cache.minFreq, 2); + ASSERT_EQ(cache.minFreq, 1); } diff --git a/src/tests/ssd_engine/engine_test.cpp b/src/tests/ssd_engine/engine_test.cpp index aad64a99485919891b8128968baa08746a735d8a..be57ad2f4914253d1bc13b4e55c9d80bf629c8fc 100644 --- a/src/tests/ssd_engine/engine_test.cpp 
+++ b/src/tests/ssd_engine/engine_test.cpp @@ -47,9 +47,9 @@ TEST(SSDEngine, CreateAndWriteAndReadAndAutoCompactAndSave) ASSERT_EQ(eng->IsTableExist(tbName), true); // write - vector keys; + vector keys; vector> embeddings; - for (emb_key_t k = 0; k < 10; k++) { + for (emb_cache_key_t k = 0; k < 10; k++) { keys.emplace_back(k); vector emb = {static_cast(k + 0.1), static_cast(k + 0.2)}; embeddings.emplace_back(emb); @@ -64,7 +64,7 @@ TEST(SSDEngine, CreateAndWriteAndReadAndAutoCompactAndSave) ASSERT_EQ(eng->GetTableAvailableSpace(tbName), maxTableSize - keys.size()); // delete and wait auto compact - vector deleteKeys = {0}; + vector deleteKeys = {0}; eng->DeleteEmbeddings(tbName, deleteKeys); this_thread::sleep_for(compactPeriod); @@ -124,9 +124,9 @@ TEST(SSDEngine, LoadAndRead) engSave->CreateTable(tbName, savePath, maxTableSize); // write - vector keys; + vector keys; vector> embeddings; - for (emb_key_t k = 0; k < 10; k++) { + for (emb_cache_key_t k = 0; k < 10; k++) { keys.emplace_back(k); vector emb = {static_cast(k + 0.1), static_cast(k + 0.2)}; embeddings.emplace_back(emb); @@ -141,7 +141,7 @@ TEST(SSDEngine, LoadAndRead) shared_ptr engLoad = make_shared(); engLoad->Start(); engLoad->Load(tbName, savePath, maxTableSize, saveStep); - for (emb_key_t k: keys) { + for (emb_cache_key_t k: keys) { ASSERT_EQ(engLoad->IsKeyExist(tbName, k), true); } auto ret = engLoad->FetchEmbeddings(tbName, keys); diff --git a/src/tests/ssd_engine/file_test.cpp b/src/tests/ssd_engine/file_test.cpp index 599b5975fbb68acfcad44492bb747597f0c5fb72..cdd80fc5d3c44b4378c3549026b6fa18449e9d96 100644 --- a/src/tests/ssd_engine/file_test.cpp +++ b/src/tests/ssd_engine/file_test.cpp @@ -100,9 +100,9 @@ TEST(File, WriteAndRead) string savePath = GlogConfig::gRankId; auto f = make_shared(0, savePath); - vector keys; + vector keys; vector> embeddings; - for (emb_key_t k = 0; k < 10; k++) { + for (emb_cache_key_t k = 0; k < 10; k++) { keys.emplace_back(k); vector emb = {static_cast(k + 0.1), static_cast(k + 0.2)}; embeddings.emplace_back(emb); @@ -129,7 +129,7 @@ TEST(File, SaveAndLoad) string fileDir = GlogConfig::gRankId; auto fTmp = make_shared(0, fileDir); - vector key = {0}; + vector key = {0}; vector> expect = {{1.0, 1.1}}; fTmp->InsertEmbeddings(key, expect); string saveDir = fileDir; // for test convenience @@ -142,3 +142,40 @@ TEST(File, SaveAndLoad) fs::remove_all(fileDir); } + +TEST(File, WriteByAddrAndRead) +{ + int rankId; + MPI_Comm_rank(MPI_COMM_WORLD, &rankId); + GlogConfig::gRankId = to_string(rankId); + + string savePath = GlogConfig::gRankId; + auto f = make_shared(0, savePath); + + vector keys; + vector embeddings; + uint64_t extEmbeddingSize = 1; + for (emb_cache_key_t k = 0; k < 10; k++) { + keys.emplace_back(k); + float* emb = new float; + *emb = static_cast(k + 0.1); + embeddings.emplace_back(emb); + } + + f->InsertEmbeddingsByAddr(keys, embeddings, extEmbeddingSize); + auto ret = f->FetchEmbeddings(keys); + for (int i = 0; i < 10; i++) { + if (std::abs(ret[i][0] - *embeddings[i]) > std::numeric_limits::epsilon()) { + FAIL() << "embedding result not equal to input"; + } + } + + for (auto emb : embeddings) + { + delete emb; + emb = nullptr; + } + + + fs::remove_all(savePath); +} \ No newline at end of file diff --git a/src/tests/ssd_engine/table_test.cpp b/src/tests/ssd_engine/table_test.cpp index 2e180c1360de419ba1a48bda1f6006dc2956129e..20a66f2f75620599259653901dd7dd20f2a6d862 100644 --- a/src/tests/ssd_engine/table_test.cpp +++ b/src/tests/ssd_engine/table_test.cpp @@ -41,13 +41,13 @@ 
TEST(Table, WriteAndReadAndDeleteAndCompact) // write emb_key_t nData = 1000000; emb_key_t batchSize = 10000; - vector allKeys; + vector allKeys; vector> allEmbs; - vector batchKeys; + vector batchKeys; vector> batchEmbs; chrono::milliseconds writeCost = 0ms; - for (emb_key_t k = 0; k < nData; k++) { + for (emb_cache_key_t k = 0; k < nData; k++) { vector emb; emb.resize(embDim); for (uint64_t i = 0; i < embDim; ++i) { @@ -122,9 +122,9 @@ TEST(Table, SaveAndLoad) // write and save emb_key_t nData = 10; - vector keys; + vector keys; vector> embs; - for (emb_key_t k = 0; k < nData; k++) { + for (emb_cache_key_t k = 0; k < nData; k++) { vector emb = {static_cast(k + 0.1), static_cast(k + 0.2)}; keys.emplace_back(k); embs.emplace_back(emb); @@ -160,7 +160,7 @@ TEST(Table, GetTableUsage) // write uint64_t expectKeyCnt = 2; - vector keys = {1, 2}; + vector keys = {1, 2}; vector> embs = {{0.1}, {0.2}}; tbSave->InsertEmbeddings(keys, embs); diff --git a/src/tests/utils/common_h_test.cpp b/src/tests/utils/common_h_test.cpp index 2e86b88de7a843a3283e39945bee5cf35c61135d..bf08919854517b0dabd9ef3465bf7990207ff5f9 100644 --- a/src/tests/utils/common_h_test.cpp +++ b/src/tests/utils/common_h_test.cpp @@ -113,12 +113,6 @@ TEST(TestHostEmbTable, DefaultConstructor) MxRec::HostEmbTable hostEmbTable; } -// 测试 EmbHashMapInfo 结构的默认构造函数 -TEST(TestEmbHashMapInfo, DefaultConstructor) -{ - MxRec::EmbHashMapInfo embHashMapInfo; -} - // 测试 All2AllInfo 结构的默认构造函数 TEST(TestAll2AllInfo, DefaultConstructor) { diff --git a/src/tests/utils/config_test.cpp b/src/tests/utils/config_test.cpp index d7e51b578be391b6077d62396a4da3969b0da973..54e0ec6756bedbdfe6c6d684a0181907742e5278 100644 --- a/src/tests/utils/config_test.cpp +++ b/src/tests/utils/config_test.cpp @@ -24,7 +24,6 @@ using namespace MxRec; void SetEnvironmentVariables() { - setenv(RecEnvNames::APPLY_GRADIENTS_STRATEGY, "sum_same_id_gradients_and_apply", 1); setenv(RecEnvNames::ACL_TIMEOUT, "100", 1); setenv(RecEnvNames::HD_CHANNEL_SIZE, "50", 1); setenv(RecEnvNames::KEY_PROCESS_THREAD_NUM, "8", 1); @@ -40,7 +39,6 @@ void SetEnvironmentVariables() void UnsetEnvironmentVariables() { - unsetenv(RecEnvNames::APPLY_GRADIENTS_STRATEGY); unsetenv(RecEnvNames::ACL_TIMEOUT); unsetenv(RecEnvNames::HD_CHANNEL_SIZE); unsetenv(RecEnvNames::KEY_PROCESS_THREAD_NUM); @@ -56,7 +54,6 @@ void UnsetEnvironmentVariables() TEST(GlobalEnv, DefaultValues) { - ASSERT_EQ(GlobalEnv::applyGradientsStrategy, ApplyGradientsStrategyOptions::SUM_SAME_ID_GRADIENTS_AND_APPLY); ASSERT_EQ(GlobalEnv::aclTimeout, -1); ASSERT_EQ(GlobalEnv::hdChannelSize, 40); ASSERT_EQ(GlobalEnv::keyProcessThreadNum, 6); @@ -77,7 +74,6 @@ TEST(GlobalEnv, ConfigGlobalEnv) ConfigGlobalEnv(); // 验证环境变量是否已经被正确配置 - ASSERT_EQ(GlobalEnv::applyGradientsStrategy, "sum_same_id_gradients_and_apply"); ASSERT_EQ(GlobalEnv::aclTimeout, 100); ASSERT_EQ(GlobalEnv::hdChannelSize, 50); ASSERT_EQ(GlobalEnv::keyProcessThreadNum, 8); diff --git a/tests/mx_rec/core/mock_class.py b/tests/mx_rec/core/mock_class.py index 7566aa1af336ac4c8650c1a7b5a13e2ff7fd81f0..e02f6257fdf3cca77783e95de78c3f08cc0a1ba8 100644 --- a/tests/mx_rec/core/mock_class.py +++ b/tests/mx_rec/core/mock_class.py @@ -20,8 +20,6 @@ import tensorflow as tf from tensorflow_core.python.training import slot_creator from mx_rec import ASCEND_GLOBAL_HASHTABLE_COLLECTION -from mx_rec.optimizers.lazy_adam import CustomizedLazyAdam -from mx_rec.util.config_utils.embedding_utils import SparseEmbedConfig from mx_rec.util.config_utils.feature_spec_utils import FeatureSpecConfig from 
mx_rec.util.config_utils.optimizer_utils import OptimizerConfig @@ -121,7 +119,6 @@ class MockConfigInitializer: def __init__(self, **kwargs): self.use_dynamic_expansion = kwargs.get("use_dynamic_expansion", False) self.use_static = kwargs.get("use_static", False) - self.use_hot = kwargs.get("use_static", True) self.modify_graph = kwargs.get("modify_graph", True) self.max_steps = kwargs.get("max_steps", -1) self.train_steps = kwargs.get("get_train_steps", -1) @@ -208,23 +205,7 @@ class MockOptimizer: def __init__(self): self.slot_num = 2 - - def initialize_slots(self, var, table_instance): - # Create slots for the first and second moments. - def creat_one_single_slot(var, op_name): - new_slot_variable = slot_creator.create_zeros_slot(var, op_name) - return new_slot_variable - - momentum = creat_one_single_slot(var, self._name + "/" + "momentum") - velocity = creat_one_single_slot(var, self._name + "/" + "velocity") - named_slot_key = (var.op.graph, var.op.name) - - table_instance.set_optimizer(self._name, {"momentum": momentum, "velocity": velocity}) - return [{"slot": momentum, "named_slot_key": named_slot_key, "slot_name": "m", "optimizer": self}, - {"slot": velocity, "named_slot_key": named_slot_key, "slot_name": "v", "optimizer": self}] - - def insert_slot(self, slot, named_slots_key, slot_name): - pass + self.derivative = 2 def get_slot_init_values(self): initial_momentum_value = 0.0 diff --git a/tests/mx_rec/core/test_build_graph.py b/tests/mx_rec/core/test_build_graph.py index c15d851f25112106412c2e750c926001b975a855..5a24fd749e9d2f2d0b936ed69390ad51725e56ac 100644 --- a/tests/mx_rec/core/test_build_graph.py +++ b/tests/mx_rec/core/test_build_graph.py @@ -21,6 +21,7 @@ from unittest import mock import tensorflow as tf from mx_rec.util.global_env_conf import global_env +from mx_rec.core.asc.build_graph import SwapInfo from tests.mx_rec.core.mock_class import MockConfigInitializer @@ -33,13 +34,13 @@ class TestGetRestoreVectorFunc(unittest.TestCase): # 默认动态扩容、hot emb、HBM self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, - use_hot=True, use_dynamic_expansion=True) + use_dynamic_expansion=True) def tearDown(self): # 恢复config self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, - use_hot=True, use_dynamic_expansion=True) + use_dynamic_expansion=True) def test_get_restore_vector_case1(self): """ @@ -115,14 +116,14 @@ class TestGetIdOffsetsFunc(unittest.TestCase): # 默认动态扩容、hot emb、HBM self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, - use_hot=True, use_dynamic_expansion=True) + use_dynamic_expansion=True) self.max_lookup_vec_size = self.config.get("send_count") * self.config.get("rank_size") def tearDown(self): # 恢复config self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, - use_hot=True, use_dynamic_expansion=True) + use_dynamic_expansion=True) @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next") def test_get_id_offsets_case1(self, mock_get_next): @@ -134,10 +135,12 @@ class TestGetIdOffsetsFunc(unittest.TestCase): with tf.Graph().as_default(): mock_get_next.return_value = [0] - id_offsets, swap_pos, swap_len = 
get_id_offsets(self.max_lookup_vec_size, self.config) + id_offsets, swap_info = get_id_offsets(self.max_lookup_vec_size, self.config) self.assertEqual(id_offsets, 0) - self.assertListEqual(swap_pos, []) - self.assertEqual(swap_len, 0) + self.assertListEqual(swap_info.swap_in_pos, []) + self.assertEqual(swap_info.swap_in_len, 0) + self.assertListEqual(swap_info.swap_out_pos, []) + self.assertEqual(swap_info.swap_out_len, 0) @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next") def test_get_id_offsets_case2(self, mock_get_next): @@ -150,88 +153,12 @@ class TestGetIdOffsetsFunc(unittest.TestCase): with tf.Graph().as_default(): self.config["use_dynamic_expansion"] = False mock_get_next.return_value = [0] - id_offsets, swap_pos, swap_len = get_id_offsets(self.max_lookup_vec_size, self.config) + id_offsets, swap_info = get_id_offsets(self.max_lookup_vec_size, self.config) self.assertEqual(id_offsets, 0) - self.assertListEqual(swap_pos, []) - self.assertEqual(swap_len, 0) - - -class TestGetRestoreVectorSecondFunc(unittest.TestCase): - """ - Test for 'mx_rec.core.asc.build_graph.get_restore_vector_second'. - """ - - def setUp(self): - # 默认动态扩容、hot emb、HBM - self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, - feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, - use_hot=True, use_dynamic_expansion=True) - self.max_lookup_vec_size = self.config.get("send_count") * self.config.get("rank_size") - - def tearDown(self): - # 恢复config - self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, - feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, - use_hot=True, use_dynamic_expansion=True) - - @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next") - def test_get_restore_vector_second(self, mock_get_next): - """ - case: test get_restore_vector_second - """ - - from mx_rec.core.asc.build_graph import get_restore_vector_second - - with tf.Graph().as_default(): - mock_get_next.return_value = [0] - restore_vector_second = get_restore_vector_second(self.max_lookup_vec_size, self.config) - self.assertEqual(restore_vector_second, 0) - - -class TestGetUniqueKeysFunc(unittest.TestCase): - """ - Test for 'mx_rec.core.asc.build_graph.get_unique_keys'. 
- """ - - def setUp(self): - # 默认动态扩容、hot emb、HBM - self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, - feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, - use_hot=True, use_dynamic_expansion=True) - self.max_lookup_vec_size = self.config.get("send_count") * self.config.get("rank_size") - - def tearDown(self): - # 恢复config - self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, - feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, - use_hot=True, use_dynamic_expansion=True) - - @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next") - def test_get_unique_keys_case1(self, mock_get_next): - """ - case1: 动态扩容 - """ - - from mx_rec.core.asc.build_graph import get_unique_keys - - with tf.Graph().as_default(): - mock_get_next.return_value = [0] - unique_keys = get_unique_keys(self.max_lookup_vec_size, self.config) - self.assertEqual(unique_keys, 0) - - @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next") - def test_get_unique_keys_case2(self, mock_get_next): - """ - case2: 非动态扩容 - """ - - from mx_rec.core.asc.build_graph import get_unique_keys - - with tf.Graph().as_default(): - self.config["use_dynamic_expansion"] = False - mock_get_next.return_value = [1] - unique_keys = get_unique_keys(self.max_lookup_vec_size, self.config) - self.assertEqual(unique_keys, 1) + self.assertListEqual(swap_info.swap_in_pos, []) + self.assertEqual(swap_info.swap_in_len, 0) + self.assertListEqual(swap_info.swap_out_pos, []) + self.assertEqual(swap_info.swap_out_len, 0) class TestGetAll2allArgsFunc(unittest.TestCase): @@ -243,13 +170,13 @@ class TestGetAll2allArgsFunc(unittest.TestCase): # 默认动态扩容、hot emb、HBM self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, - use_hot=True, use_dynamic_expansion=True) + use_dynamic_expansion=True) def tearDown(self): # 恢复config self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, - use_hot=True, use_dynamic_expansion=True) + use_dynamic_expansion=True) def test_get_all2all_args_case1(self): """ @@ -276,60 +203,6 @@ class TestGetAll2allArgsFunc(unittest.TestCase): self.assertEqual(all2all_args, 0) -class TestGetSwapInfoFunc(unittest.TestCase): - """ - Test for 'mx_rec.core.asc.build_graph.get_swap_info'. 
- """ - - def setUp(self): - # 默认动态扩容、hot emb、HBM - self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, - feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, - use_hot=True, use_dynamic_expansion=True) - - def tearDown(self): - # 恢复config - self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, - feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, - use_hot=True, use_dynamic_expansion=True) - - @mock.patch("mx_rec.core.asc.build_graph.ConfigInitializer") - def test_get_swap_info_case1(self, build_graph_config_initializer): - """ - case1: 静态shape,HBM - """ - - from mx_rec.core.asc.build_graph import get_swap_info - - with tf.Graph().as_default(): - mock_config_initializer = MockConfigInitializer(use_static=True) - build_graph_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) - - swap_in = get_swap_info(self.config, None, None, None) - self.assertIsInstance(swap_in[0], type(tf.no_op())) - - @mock.patch("mx_rec.core.asc.build_graph.ConfigInitializer") - @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next") - def test_get_swap_info_case2(self, mock_get_next, build_graph_config_initializer): - """ - case2: 静态shape,非HBM,table传入非list,抛出异常 - """ - - from mx_rec.core.asc.build_graph import get_swap_info - - with tf.Graph().as_default(): - mock_config_initializer = MockConfigInitializer(use_static=True) - build_graph_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) - - mock_get_next.return_value = tf.ones(shape=[8, 8], dtype=tf.float32) - swap_pos = tf.constant([8, 9], dtype=tf.int32) - swap_len = tf.constant(2, dtype=tf.int32) - table = tf.compat.v1.get_variable("test_table", shape=[10, 8], initializer=tf.ones_initializer()) - self.config["is_hbm"] = False - with self.assertRaises(RuntimeError): - get_swap_info(self.config, swap_len, swap_pos, table) - - class TestGetPreProcessedTensorForAscFunc(unittest.TestCase): """ Test for 'mx_rec.core.asc.build_graph.get_preprocessed_tensor_for_asc'. 
@@ -339,22 +212,18 @@ class TestGetPreProcessedTensorForAscFunc(unittest.TestCase): # 默认动态扩容、hot emb、HBM self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, - use_hot=True, use_dynamic_expansion=True) + use_dynamic_expansion=True) def tearDown(self): # 恢复config self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, - use_hot=True, use_dynamic_expansion=True) - global_env.apply_gradients_strategy = "direct_apply" + use_dynamic_expansion=True) @mock.patch.multiple("mx_rec.core.asc.build_graph", get_restore_vector=mock.MagicMock(return_value=[0, 0]), - get_id_offsets=mock.MagicMock(return_value=[0, 0, 0]), - get_all2all_args=mock.MagicMock(return_value=0), - get_swap_info=mock.MagicMock(return_value=0), - get_restore_vector_second=mock.MagicMock(return_value=0), - get_unique_keys=mock.MagicMock(return_value=0)) + get_id_offsets=mock.MagicMock(return_value=[0, SwapInfo()]), + get_all2all_args=mock.MagicMock(return_value=0)) @mock.patch("mx_rec.core.asc.build_graph.ConfigInitializer") def test_get_preprocessed_tensor_for_asc_case1(self, build_graph_config_initializer): """ @@ -363,23 +232,17 @@ class TestGetPreProcessedTensorForAscFunc(unittest.TestCase): from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc - global_env.apply_gradients_strategy = "sum_same_id_gradients_and_apply" with tf.Graph().as_default(): mock_config_initializer = MockConfigInitializer(use_static=True) build_graph_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) result = get_preprocessed_tensor_for_asc(None, self.config) self.assertIsNotNone(result.get("restore_vector")) - self.assertIsNotNone(result.get("restore_vector_second")) - self.assertIsNotNone(result.get("unique_keys")) @mock.patch.multiple("mx_rec.core.asc.build_graph", get_restore_vector=mock.MagicMock(return_value=[0, 0]), - get_id_offsets=mock.MagicMock(return_value=[0, 0, 0]), - get_all2all_args=mock.MagicMock(return_value=0), - get_swap_info=mock.MagicMock(return_value=0), - get_restore_vector_second=mock.MagicMock(return_value=0), - get_unique_keys=mock.MagicMock(return_value=0)) + get_id_offsets=mock.MagicMock(return_value=[0, SwapInfo()]), + get_all2all_args=mock.MagicMock(return_value=0)) @mock.patch("mx_rec.core.asc.build_graph.ConfigInitializer") def test_get_preprocessed_tensor_for_asc_case2(self, build_graph_config_initializer): """ @@ -388,23 +251,17 @@ class TestGetPreProcessedTensorForAscFunc(unittest.TestCase): from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc - global_env.apply_gradients_strategy = "sum_same_id_gradients_and_apply" with tf.Graph().as_default(): mock_config_initializer = MockConfigInitializer() build_graph_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) result = get_preprocessed_tensor_for_asc(None, self.config) self.assertIsNotNone(result.get("restore_vector")) - self.assertIsNotNone(result.get("restore_vector_second")) - self.assertIsNotNone(result.get("unique_keys")) @mock.patch.multiple("mx_rec.core.asc.build_graph", get_restore_vector=mock.MagicMock(return_value=[0, 0]), - get_id_offsets=mock.MagicMock(return_value=[0, 0, 0]), - get_all2all_args=mock.MagicMock(return_value=0), - get_swap_info=mock.MagicMock(return_value=0), - get_restore_vector_second=mock.MagicMock(return_value=0), - 
get_unique_keys=mock.MagicMock(return_value=0)) + get_id_offsets=mock.MagicMock(return_value=[0, SwapInfo]), + get_all2all_args=mock.MagicMock(return_value=0)) @mock.patch("mx_rec.core.asc.build_graph.ConfigInitializer") def test_get_preprocessed_tensor_for_asc_case3(self, build_graph_config_initializer): """ @@ -413,7 +270,6 @@ class TestGetPreProcessedTensorForAscFunc(unittest.TestCase): from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc - global_env.apply_gradients_strategy = "sum_same_id_gradients_and_apply" with tf.Graph().as_default(): mock_config_initializer = MockConfigInitializer() build_graph_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) @@ -421,7 +277,6 @@ class TestGetPreProcessedTensorForAscFunc(unittest.TestCase): self.config["channel_id"] = 1 result = get_preprocessed_tensor_for_asc(None, self.config) self.assertIsNotNone(result.get("restore_vector")) - self.assertIsNone(result.get("restore_vector_second")) if __name__ == '__main__': diff --git a/tests/mx_rec/core/test_embedding.py b/tests/mx_rec/core/test_embedding.py index bf7d9240d3f740bdc4d5a99ba775dd747f29bde5..509b9ae7be51f16621f1550863318118eba67938 100644 --- a/tests/mx_rec/core/test_embedding.py +++ b/tests/mx_rec/core/test_embedding.py @@ -88,7 +88,8 @@ class TestCreateTableFunc(unittest.TestCase): test_table = create_table(key_dtype=tf.int64, dim=8, name='test_table', - emb_initializer=tf.compat.v1.truncated_normal_initializer()) + emb_initializer=tf.compat.v1.truncated_normal_initializer(), + device_vocabulary_size=8) self.assertIsInstance(test_table, HBMSparseEmbedding) @mock.patch.multiple("mx_rec.core.emb.base_sparse_embedding", @@ -120,8 +121,8 @@ class TestCreateTableFunc(unittest.TestCase): dim=8, name='test_table', emb_initializer=tf.compat.v1.truncated_normal_initializer(), - host_vocabulary_size=8, - optimizer_list=[create_hash_optimizer(learning_rate=0.01)]) + device_vocabulary_size=8, + host_vocabulary_size=8) self.assertIsInstance(test_table, ExternalStorageSparseEmbedding) @@ -134,12 +135,11 @@ class TestSparseLookupFunc(unittest.TestCase): get_rank_size=mock.MagicMock(return_value=8), get_rank_id=mock.MagicMock(return_value=0), get_device_id=mock.MagicMock(return_value=0)) - @mock.patch("mx_rec.core.emb.sparse_embedding.get_preprocessed_tensor_for_asc") + @mock.patch("mx_rec.core.emb.base_sparse_embedding.get_preprocessed_tensor_for_asc") @mock.patch("mx_rec.core.embedding.ConfigInitializer") @mock.patch("mx_rec.core.emb.base_sparse_embedding.ConfigInitializer") @mock.patch("mx_rec.validator.emb_validator.ConfigInitializer") - @mock.patch("mx_rec.core.emb.sparse_embedding.ConfigInitializer") - def test_sparse_lookup_case1(self, embedding_config_initializer, base_sparse_embedding_config_initializer, + def test_sparse_lookup_case1(self, base_sparse_embedding_config_initializer, emb_validator_config_initializer, sparse_embedding_config_initializer, mock_get_preprocessed_tensor_for_asc): """ @@ -154,7 +154,6 @@ class TestSparseLookupFunc(unittest.TestCase): # mock mock_config_initializer = MockConfigInitializer(use_dynamic_expansion=False) - embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) base_sparse_embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) emb_validator_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) sparse_embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) @@ -166,12 +165,9 @@ 
class TestSparseLookupFunc(unittest.TestCase): batch = {"case1_feat": tf.ones(shape=[8, 8], dtype=tf.int64)} mock_get_preprocessed_tensor_for_asc.return_value = { "restore_vector": tf.ones(shape=[8, 8], dtype=tf.int64), - "restore_vector_second": tf.ones(shape=[8, ], dtype=tf.int64), - "unique_keys": tf.ones(shape=[8, ], dtype=tf.int64), "hot_pos": tf.ones(shape=[8, ], dtype=tf.int64), "id_offsets": tf.ones(shape=[8, ], dtype=tf.int64), - "all2all_args": tf.ones(shape=[8, 8], dtype=tf.int64), - "swap_in": [tf.no_op()] + "all2all_args": tf.ones(shape=[8, 8], dtype=tf.int64) } # test @@ -190,12 +186,11 @@ class TestSparseLookupFunc(unittest.TestCase): get_rank_id=mock.MagicMock(return_value=0), get_device_id=mock.MagicMock(return_value=0)) @mock.patch("mx_rec.core.asc.feature_spec.ConfigInitializer") - @mock.patch("mx_rec.core.emb.sparse_embedding.get_preprocessed_tensor_for_asc") + @mock.patch("mx_rec.core.emb.base_sparse_embedding.get_preprocessed_tensor_for_asc") @mock.patch("mx_rec.core.embedding.ConfigInitializer") @mock.patch("mx_rec.core.emb.base_sparse_embedding.ConfigInitializer") @mock.patch("mx_rec.validator.emb_validator.ConfigInitializer") - @mock.patch("mx_rec.core.emb.sparse_embedding.ConfigInitializer") - def test_sparse_lookup_case2(self, embedding_config_initializer, base_sparse_embedding_config_initializer, + def test_sparse_lookup_case2(self, base_sparse_embedding_config_initializer, emb_validator_config_initializer, sparse_embedding_config_initializer, mock_get_preprocessed_tensor_for_asc, feature_spec_config_initializer): """ @@ -210,7 +205,6 @@ class TestSparseLookupFunc(unittest.TestCase): # mock mock_config_initializer = MockConfigInitializer(use_dynamic_expansion=False) - embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) base_sparse_embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) emb_validator_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) sparse_embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) @@ -219,12 +213,9 @@ class TestSparseLookupFunc(unittest.TestCase): case2_feat = tf.ones(shape=[8, 8], dtype=tf.int64) mock_get_preprocessed_tensor_for_asc.return_value = { "restore_vector": tf.ones(shape=[8, 8], dtype=tf.int64), - "restore_vector_second": tf.ones(shape=[8, ], dtype=tf.int64), - "unique_keys": tf.ones(shape=[8, ], dtype=tf.int64), "hot_pos": tf.ones(shape=[8, ], dtype=tf.int64), "id_offsets": tf.ones(shape=[8, ], dtype=tf.int64), - "all2all_args": tf.ones(shape=[8, 8], dtype=tf.int64), - "swap_in": [tf.no_op()] + "all2all_args": tf.ones(shape=[8, 8], dtype=tf.int64) } # test diff --git a/tests/mx_rec/core/test_feature_process.py b/tests/mx_rec/core/test_feature_process.py index b8bb07429ade00e43dc43a4bd5f4e6e9c5f4bb95..787648f486264d49b7a6697eb9157f5f4a11551c 100644 --- a/tests/mx_rec/core/test_feature_process.py +++ b/tests/mx_rec/core/test_feature_process.py @@ -78,13 +78,13 @@ class TestAfterRunFuncOfEvictHookClass(TestEvictHookClass): mock_get_next.return_value = [tf.constant([8, 9], dtype=tf.int32), tf.constant(2, dtype=tf.int32)] - evict_hook = EvictHook(evict_enable=True, evict_time_interval=1) + evict_hook = EvictHook(evict_enable=True, evict_time_interval=10) with tf.compat.v1.train.MonitoredSession(hooks=[evict_hook]) as sess: sess.graph._unsafe_unfinalize() sess.run(tf.compat.v1.global_variables_initializer()) # sleep 1s 等待淘汰时间evict_time_interval - time.sleep(1) + time.sleep(10) # 
获取原variable,淘汰会发生在此session run之后 ori_variable = sess.run(test_table.variable) diff --git a/tests/mx_rec/core/test_manager.py b/tests/mx_rec/core/test_manager.py index 815ad843b40f8efa4dc4d3ec4e9593896f092add..70c2f1500b39d30b385893478c528dfdcb979df4 100644 --- a/tests/mx_rec/core/test_manager.py +++ b/tests/mx_rec/core/test_manager.py @@ -74,7 +74,6 @@ class TestGenerateTableInfoListFunc(unittest.TestCase): mock_opt = MockOptimizer() manager_config_initializer.get_instance().optimizer_config.optimizer_instance = mock_opt - test_table.optimizer_instance_list = [mock_opt] table_info_list = generate_table_info_list() self.assertListEqual(table_info_list, []) @@ -100,7 +99,6 @@ class TestGenerateTableInfoListFunc(unittest.TestCase): mock_opt = MockOptimizer() manager_config_initializer.get_instance().optimizer_config.optimizer_instance = mock_opt - test_table.optimizer_instance_list = [mock_opt] table_info_list = generate_table_info_list() self.assertListEqual(table_info_list, []) @@ -139,7 +137,6 @@ class TestGenerateTableInfoListFunc(unittest.TestCase): mock_opt = MockOptimizer() manager_config_initializer.get_instance().optimizer_config.optimizer_instance = mock_opt - test_table.optimizer_instance_list = [mock_opt] table_info_list = generate_table_info_list() self.assertListEqual(table_info_list, ["test_table_info"]) @@ -338,7 +335,6 @@ class TestMatchedOptSlotInitializersFunc(unittest.TestCase): table_instance.ext_emb_size = 24 mock_opt = MockOptimizer() manager_config_initializer.get_instance().optimizer_config.optimizer_instance = mock_opt - table_instance.optimizer_instance_list = [mock_opt] slot_initializers = matched_opt_slot_initializers(table_instance) self.assertListEqual(slot_initializers, ["slot_initializer", "slot_initializer"]) @@ -383,8 +379,8 @@ class TestInitializeEmbCacheFunc(unittest.TestCase): get_device_id=mock.MagicMock(return_value=0), get_rank_size=mock.MagicMock(return_value=0), USE_STATIC=mock.MagicMock(return_value=0), - USE_HOT=mock.MagicMock(return_value=1), USE_DYNAMIC_EXPANSION=mock.MagicMock(return_value=2), + USE_SUM_SAME_ID_GRADIENTS=mock.MagicMock(return_value=4), RankInfo=mock.MagicMock(return_value="mock_info"), HybridMgmt=mock.MagicMock(return_value=MockHybridMgmt(is_initialized=False))) @mock.patch("mx_rec.core.asc.manager.ConfigInitializer") @@ -398,6 +394,9 @@ class TestInitializeEmbCacheFunc(unittest.TestCase): mock_config_initializer = MockConfigInitializer(use_static=True, use_dynamic_expansion=True) manager_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + mock_opt = MockOptimizer() + manager_config_initializer.get_instance().optimizer_config.optimizer_instance = mock_opt + with self.assertRaises(RuntimeError): initialize_emb_cache([], []) @@ -406,8 +405,8 @@ class TestInitializeEmbCacheFunc(unittest.TestCase): get_device_id=mock.MagicMock(return_value=0), get_rank_size=mock.MagicMock(return_value=0), USE_STATIC=mock.MagicMock(return_value=0), - USE_HOT=mock.MagicMock(return_value=1), USE_DYNAMIC_EXPANSION=mock.MagicMock(return_value=2), + USE_SUM_SAME_ID_GRADIENTS=mock.MagicMock(return_value=4), RankInfo=mock.MagicMock(return_value="mock_info")) @mock.patch("mx_rec.core.asc.manager.ConfigInitializer") @mock.patch("mx_rec.core.asc.manager.HybridMgmt") @@ -421,6 +420,9 @@ class TestInitializeEmbCacheFunc(unittest.TestCase): mock_config_initializer = MockConfigInitializer(use_static=True, use_dynamic_expansion=True) manager_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + mock_opt 
= MockOptimizer() + manager_config_initializer.get_instance().optimizer_config.optimizer_instance = mock_opt + mock_mgmt = MockHybridMgmt(is_initialized=True) mock_hybrid_mgmt.return_value = mock_mgmt initialize_emb_cache([], []) diff --git a/tests/mx_rec/graph/test_acg_push_ops.py b/tests/mx_rec/graph/test_acg_push_ops.py deleted file mode 100644 index 129b773f8145a6190a11d15371b5e2bd5bf7cfd4..0000000000000000000000000000000000000000 --- a/tests/mx_rec/graph/test_acg_push_ops.py +++ /dev/null @@ -1,514 +0,0 @@ -#!/usr/bin/env python3 -# coding: UTF-8 -# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from unittest import TestCase -from unittest.mock import patch, Mock - -import tensorflow as tf -from tensorflow.core.framework import node_def_pb2 -from tensorflow.python.data.ops.dataset_ops import DatasetV1 -from mx_rec.graph.acg_push_ops import ( - ACGPushOpsToDatasetHook, - SubgraphInfo, - _OP_NAME_CONTAIN_STRING_TO_PUSH, - _ACG_NEW_INITIALIZER, - _find_ops_to_be_pushed, - _find_op_from_base_op, - _find_subgraph_nodes, - _get_mapping_tensor, - _topo_subgraph, - _get_dataset_op, - _clone_subgraph_into_funcgraph, - _update_subgraph_out_consumer, - _get_src_dataset, - _update_iterator_getnext, - _find_subgraph_in_out, - _push_subgraph_to_dataset, - _warn_for_var_scope_nodes, - _frozen_variable_node_to_func_const_node_def, - _update_old_consumer, - _get_mapping_for_subgraph, - _get_mapping_for_subgraph_in, - _ordered_output_from_subgraph, - _replace_get_next_op, - _patched_get_src_dataset, -) -from tests.mx_rec.core.mock_class import MockConfigInitializer -from tests.mx_rec.graph.mock_dataset import gen_mock_dataset - - -@patch.multiple( - "mx_rec.graph.patch", - ConfigInitializer=Mock(return_value=MockConfigInitializer(modify_graph=True, is_graph_modify_hook_running=True)), -) -@patch.multiple( - "tensorflow.compat.v1.train.Saver", - __init__=Mock(return_value=None), - build=Mock(), -) -@patch.multiple("mx_rec.graph.acg_push_ops", _find_ops_to_be_pushed=Mock()) -class ACGPushOpsToDatasetHookTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok(self): - mock_dataset = gen_mock_dataset() - mock_iterator = mock_dataset.make_initializable_iterator() - mock_batch = mock_iterator.get_next() - mock_ids = mock_batch.get("mock_ids") - mock_cutting_point = tf.identity(mock_ids) - - mock_new_iterator = mock_dataset.make_initializable_iterator() - tf.compat.v1.add_to_collection(_ACG_NEW_INITIALIZER, mock_new_iterator.initializer) - - with tf.compat.v1.train.MonitoredSession(hooks=[ACGPushOpsToDatasetHook()]) as sess: - sess.run(mock_iterator.initializer) - sess.run(mock_cutting_point) - - -@patch.multiple( - "mx_rec.graph.acg_push_ops", - _find_subgraph_nodes=Mock(return_value=set()), - _push_subgraph_to_dataset=Mock(), -) -class FindOpsToBePushedTest(TestCase): - def tearDown(self) -> None: - 
tf.compat.v1.reset_default_graph() - - def test_ok_op_contain_str_to_push(self): - tensor = tf.constant(value=[1, 2, 3], name="MOCK" + list(_OP_NAME_CONTAIN_STRING_TO_PUSH)[0]) - mock_graph = tf.compat.v1.get_default_graph() - _find_ops_to_be_pushed(mock_graph) - - def test_ok_op_type_to_push(self): - const_tensor = tf.constant(value=[1, 2, 3], dtype=tf.int32) - str_tensor = tf.compat.v1.as_string(const_tensor) - num_tensor = tf.compat.v1.string_to_number(str_tensor) - mock_graph = tf.compat.v1.get_default_graph() - _find_ops_to_be_pushed(mock_graph) - - def test_ok_no_node_to_push(self): - mock_graph = tf.compat.v1.get_default_graph() - _find_ops_to_be_pushed(mock_graph) - - -class FindSubgraphNodesTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok(self): - mock_dataset = gen_mock_dataset() - mock_iterator = mock_dataset.make_initializable_iterator() - mock_batch = mock_iterator.get_next() - mock_ids = mock_batch.get("mock_ids") - - tensor_in_subgraph = tf.identity(mock_ids) - tensor_out_subgraph = tf.identity(tensor_in_subgraph) - mock_base_nodes = {tensor_out_subgraph.op} - - subgraph_nodes = _find_subgraph_nodes( - tf.compat.v1.get_default_graph(), mock_base_nodes, tgt_op_type="IteratorGetNext" - ) - self.assertEqual(subgraph_nodes, {tensor_in_subgraph.op, tensor_out_subgraph.op}) - - -class WarnForVarScopeNodesTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok(self): - with tf.compat.v1.variable_scope("mock_var_scope"): - var1 = tf.compat.v1.get_variable("var", shape=(3, 3), initializer=tf.random_normal_initializer()) - - mock_all_nodes = tf.compat.v1.get_default_graph().get_operations() - mock_base_node = var1.op - _warn_for_var_scope_nodes(mock_all_nodes, mock_base_node) - - -class FindOpFromBaseOpTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_err_no_tgt_op_type(self): - parent_tensor = tf.ones(shape=(3, 3)) - child_tensor = tf.identity(parent_tensor) - with self.assertRaises(ValueError): - _find_op_from_base_op(child_tensor.op, "IteratorGetNext") - - -class GetDatasetOpTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok(self): - mock_dataset = gen_mock_dataset() - mock_prefetch_dataset = mock_dataset.prefetch(buffer_size=10) - mock_iterator = mock_prefetch_dataset.make_initializable_iterator() - mock_batch = mock_iterator.get_next() - mock_ids = mock_batch.get("mock_ids") - mock_get_next_op = mock_ids.op - - mock_graph = tf.compat.v1.get_default_graph() - expected = mock_graph.get_operation_by_name("OptimizeDataset") - - tgt_dataset_op = _get_dataset_op(mock_graph, mock_get_next_op) - self.assertEqual(tgt_dataset_op, expected) - - def test_err_invalid_get_next_op_type(self): - mock_get_next_op = tf.zeros(shape=(3,)).op - mock_graph = tf.compat.v1.get_default_graph() - - with self.assertRaises(TypeError): - _get_dataset_op(mock_graph, mock_get_next_op) - - @patch.multiple("mx_rec.graph.acg_push_ops", _find_op_from_base_op=Mock(return_value=None)) - @patch.multiple("mx_rec.graph.acg_push_ops.modifier", find_parent_op=Mock(return_value=None)) - def test_err_no_tgt_op_found(self): - mock_dataset = gen_mock_dataset() - mock_iterator = mock_dataset.make_initializable_iterator() - mock_batch = mock_iterator.get_next() - mock_ids = mock_batch.get("mock_ids") - mock_get_next_op = mock_ids.op - - mock_graph = tf.compat.v1.get_default_graph() - - with self.assertRaises(RuntimeError): - 
_get_dataset_op(mock_graph, mock_get_next_op) - - -class OrderedOutputFromSubgraphTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok(self): - mock_dataset = gen_mock_dataset() - mock_iterator = mock_dataset.make_initializable_iterator() - mock_batch = mock_iterator.get_next(name="IteratorGetNext") - mock_ids = mock_batch.get("mock_ids") - - mock_subgraph_out = {tf.identity(mock_ids).op: {mock_ids.op}} - - addition_funcgraph_output_tensor = _ordered_output_from_subgraph(mock_subgraph_out) - self.assertEqual(addition_funcgraph_output_tensor, [mock_ids]) - - -class PushSubgraphToDatasetTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok(self): - mock_dataset = gen_mock_dataset() - mock_iterator = mock_dataset.make_initializable_iterator() - mock_batch = mock_iterator.get_next() - mock_ids = mock_batch.get("mock_ids") - - tensor_in_subgraph = tf.identity(mock_ids) - tensor_out_subgraph = tf.identity(tensor_in_subgraph) - mock_subgraph_to_push = {tensor_in_subgraph.op} - _push_subgraph_to_dataset(tf.compat.v1.get_default_graph(), mock_subgraph_to_push) - - -class FindSubgraphInOutTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok(self): - mock_dataset = gen_mock_dataset() - mock_iterator = mock_dataset.make_initializable_iterator() - mock_batch = mock_iterator.get_next() - mock_ids = mock_batch.get("mock_ids") - - tensor_in_subgraph = tf.identity(mock_ids) - tensor_out_subgraph = tf.identity(tensor_in_subgraph) - mock_subgraph_nodes = {tensor_in_subgraph.op} - - ( - subgraph_in, - subgraph_out, - ) = _find_subgraph_in_out(mock_subgraph_nodes) - self.assertEqual(subgraph_in, {mock_ids.op: {tensor_in_subgraph.op}}) - self.assertEqual(subgraph_out, {tensor_out_subgraph.op: {tensor_in_subgraph.op}}) - - -class GetSrcDatasetTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok_make_iterator(self): - mock_dataset = gen_mock_dataset() - mock_iterator = mock_dataset.make_initializable_iterator() - mock_batch = mock_iterator.get_next() - mock_ids = mock_batch.get("mock_ids") - mock_get_next_op = mock_ids.op - - src_dataset = _get_src_dataset(tf.compat.v1.get_default_graph(), mock_get_next_op) - self.assertEqual(src_dataset, mock_dataset) - - def test_ok_one_shot_iterator(self): - mock_dataset = gen_mock_dataset() - mock_prefetch_dataset = mock_dataset.prefetch(10) - mock_iterator = mock_prefetch_dataset.make_one_shot_iterator() - mock_batch = mock_iterator.get_next() - mock_ids = mock_batch.get("mock_ids") - mock_get_next_op = mock_ids.op - - src_dataset = _get_src_dataset(tf.compat.v1.get_default_graph(), mock_get_next_op) - self.assertEqual(src_dataset, mock_dataset) - - def test_err_no_anchor_dataset(self): - mock_dataset = gen_mock_dataset() - mock_iterator = mock_dataset.make_one_shot_iterator() - mock_batch = mock_iterator.get_next() - mock_ids = mock_batch.get("mock_ids") - mock_get_next_op = mock_ids.op - - with self.assertRaises(RuntimeError): - _get_src_dataset(tf.compat.v1.get_default_graph(), mock_get_next_op) - - -class CloneSubgraphIntoFuncgraphTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok(self): - mock_dataset = gen_mock_dataset() - mock_iterator = mock_dataset.make_initializable_iterator() - mock_batch = mock_iterator.get_next() - mock_ids = mock_batch.get("mock_ids") - - mock_subgraph_in = {mock_ids.op: {tf.identity(mock_ids).op}} - mock_subgraph_out 
= {tf.identity(mock_ids).op: {mock_ids.op}} - mock_subgraph_to_push = set() - mock_subgraph_info = SubgraphInfo(mock_subgraph_in, mock_subgraph_out, mock_subgraph_to_push) - - mock_new_ids = tf.ones_like(mock_ids) - mock_x = [mock_new_ids] - mock_old_x = ({"mock_new_ids": mock_new_ids},) - - mock_defaultgraph = tf.compat.v1.get_default_graph() - with tf.Graph().as_default(): - mock_funcgraph = tf.compat.v1.get_default_graph() - _clone_subgraph_into_funcgraph(mock_funcgraph, mock_defaultgraph, mock_subgraph_info, mock_x, mock_old_x) - - -class GetMappingForSubgraphInTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok(self): - mock_dataset = gen_mock_dataset() - mock_prefetch_dataset = mock_dataset.prefetch(10) - mock_iterator = mock_prefetch_dataset.make_one_shot_iterator() - mock_batch = mock_iterator.get_next() - mock_ids = mock_batch.get("mock_ids") - - mock_from_node = mock_ids.op - mock_to_nodes = {tf.identity(mock_ids).op} - mock_new_ids = tf.zeros_like(mock_ids) - mock_x = [mock_new_ids] - tensor_mapping = dict() - - _get_mapping_for_subgraph_in(mock_from_node, mock_to_nodes, mock_x, tensor_mapping) - self.assertEqual(tensor_mapping, {mock_ids: mock_new_ids}) - - -class GetMappingForSubgraphTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok(self): - mock_defaultgraph = tf.compat.v1.get_default_graph() - - # NOTE: Simulate independent graph environment while executing `dataset.map()` method. - with tf.Graph().as_default(): - key_tensor = tf.zeros(shape=(1)) - val_tensor = tf.zeros(shape=(1)) - mock_tensor_mapping = {key_tensor: val_tensor} - - mock_node_mapping = dict() - mock_old_node = tf.identity(key_tensor).op - mock_funcgraph = tf.compat.v1.get_default_graph() - - _get_mapping_for_subgraph( - mock_funcgraph, mock_defaultgraph, mock_node_mapping, mock_old_node, mock_tensor_mapping - ) - - self.assertEqual(len(mock_node_mapping), 1) - self.assertEqual(len(mock_tensor_mapping), 2) - - -@patch.multiple( - "mx_rec.graph.patch", - ConfigInitializer=Mock(return_value=MockConfigInitializer(modify_graph=True, is_graph_modify_hook_running=True)), -) -class FrozenVariableNodeToFuncConstNodeDefTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok(self): - var_tensor = tf.Variable(initial_value=[1], shape=(1,)) - tf.compat.v1.assign(ref=var_tensor, value=[1]) - - mock_funcgraph = tf.Graph() - mock_defaultgraph = tf.compat.v1.get_default_graph() - new_const_node: node_def_pb2.NodeDef = _frozen_variable_node_to_func_const_node_def( - var_tensor.op, mock_funcgraph, mock_defaultgraph - ) - self.assertEqual(new_const_node.op, "Const") - - -class GetMappingTensorTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok(self): - key_tensor = tf.zeros(shape=(3, 3)) - val_tensor = tf.ones(shape=(3, 3)) - tsr2tsr = {key_tensor: val_tensor} - keys = [key_tensor] - - mapped_tensors = _get_mapping_tensor(tsr2tsr, keys) - self.assertEqual(mapped_tensors, [val_tensor]) - - def test_err_key_tensor_not_exist(self): - tsr2tsr = {tf.zeros(shape=(3, 3)): tf.ones(shape=(3, 3))} - keys = [tf.ones(shape=(3, 3))] - - with self.assertRaises(KeyError): - _get_mapping_tensor(tsr2tsr, keys) - - -class TopoSubgraphTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok(self): - mock_dataset = gen_mock_dataset() - mock_prefetch_dataset = mock_dataset.prefetch(10) - mock_iterator = 
mock_prefetch_dataset.make_one_shot_iterator() - mock_batch = mock_iterator.get_next() - mock_ids = mock_batch.get("mock_ids") - mock_get_next_op = mock_ids.op - - tensor1 = tf.identity(mock_ids) - tensor2 = tf.add(tensor1, 1) - mock_subgraph = {tensor1.op, tensor2.op} - - const_op_for_add = None - for tensor in tensor2.op.inputs: - if tensor.op.name != "Add/y": - continue - const_op_for_add = tensor.op - - if not const_op_for_add: - self.fail( - f"Failed to find input of add operation, input tensor of add op: {[x.op for x in tensor2.op.inputs]}" - ) - - topo_subgraph_list = _topo_subgraph(mock_subgraph) - self.assertEqual(topo_subgraph_list, [tensor1.op, const_op_for_add, tensor2.op]) - - -class UpdateIteratorGetNextTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok(self): - mock_old_dataset = gen_mock_dataset() - mock_old_iterator = mock_old_dataset.make_initializable_iterator() - mock_old_batch = mock_old_iterator.get_next(name="OldIteratorGetNext") - mock_old_ids = mock_old_batch.get("mock_ids") - mock_old_get_next_op = mock_old_ids.op - - mock_new_dataset: DatasetV1 = mock_old_dataset.map(lambda x: x) - mock_subgraph_out = {tf.identity(mock_old_ids).op: {mock_old_ids.op}} - - _update_iterator_getnext( - graph=tf.compat.v1.get_default_graph(), - get_next_op=mock_old_get_next_op, - tgt_dataset=mock_new_dataset, - subgraph_out=mock_subgraph_out, - subgraph_to_push=set(), - ) - - -class UpdateOldConsumerTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok(self): - mock_dataset = gen_mock_dataset() - mock_iterator = mock_dataset.make_initializable_iterator() - mock_batch = mock_iterator.get_next(name="NewIteratorGetNext") - mock_ids = mock_batch.get("mock_ids") - mock_new_get_next_op = mock_ids.op - mock_output_tensor = tf.identity(mock_ids) - - _update_old_consumer( - graph=tf.compat.v1.get_default_graph(), - new_get_next_op=mock_new_get_next_op, - output_tensor=mock_ids, - subgraph_to_push=set(), - ) - - -class UpdateSubgraphOutConsumerTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok(self): - mock_dataset = gen_mock_dataset() - mock_iterator = mock_dataset.make_initializable_iterator() - mock_batch = mock_iterator.get_next(name="NewIteratorGetNext") - mock_ids = mock_batch.get("mock_ids") - mock_new_get_next_op = mock_ids.op - mock_output_tensor = tf.identity(mock_ids) - - _update_subgraph_out_consumer( - graph=tf.compat.v1.get_default_graph(), - new_get_next_op=mock_new_get_next_op, - offset=0, - output_tensor=mock_ids, - ) - - -class PatchedGetSrcDatasetTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok(self): - mock_dataset = gen_mock_dataset() - mock_prefetch_dataset = mock_dataset.prefetch(10) - mock_double_prefetch_dataset = mock_prefetch_dataset.prefetch(10) - mock_iterator = mock_prefetch_dataset.make_one_shot_iterator() - mock_batch = mock_iterator.get_next() - mock_ids = mock_batch.get("mock_ids") - mock_get_next_op = mock_ids.op - - src_dataset = _patched_get_src_dataset(mock_get_next_op, is_training=True) - self.assertEqual(src_dataset, mock_prefetch_dataset) - - def test_err_single_prefetch_dataset(self): - mock_dataset = gen_mock_dataset() - mock_prefetch_dataset = mock_dataset.prefetch(10) - mock_iterator = mock_prefetch_dataset.make_one_shot_iterator() - mock_batch = mock_iterator.get_next() - mock_ids = mock_batch.get("mock_ids") - mock_get_next_op = mock_ids.op - - with 
self.assertRaises(RuntimeError): - _patched_get_src_dataset(mock_get_next_op, is_training=True) diff --git a/tests/mx_rec/graph/test_modifier.py b/tests/mx_rec/graph/test_modifier.py index 14b87617ff739966ac1dc3bef736131724f1d0a8..25caf4296c7e28eeac27c60b216bee0d1af3689b 100644 --- a/tests/mx_rec/graph/test_modifier.py +++ b/tests/mx_rec/graph/test_modifier.py @@ -18,7 +18,7 @@ import unittest from collections import defaultdict from unittest import TestCase -from unittest.mock import patch, Mock +from unittest.mock import patch, Mock, MagicMock from typing import Union, Callable import tensorflow as tf @@ -30,24 +30,15 @@ from mx_rec.constants.constants import ( ASCEND_TIMESTAMP, ASCAnchorAttr, ) -from mx_rec.core.asc import FeatureSpec -from mx_rec.graph.graph_typing import AnchorRecord from mx_rec.graph.modifier import ( GraphModifierHook, - find_make_iterator_op, - find_target_dataset_op, - find_target_instance_dataset, - generate_get_next_op_specs, - get_dataset_op, - get_input_index_list, - get_passing_tensor_list, - get_preprocessing_map_func, - get_src_dataset, - get_tgt_dataset, - get_timestamp_index, - modify_graph_for_asc, + _GraphModifier, + _AnchorRecord, + _get_input_index_list, + _get_passing_tensor_list, + _get_timestamp_index, ) -from tests.mx_rec.core.mock_class import MockConfigInitializer +from tests.mx_rec.core.mock_class import MockConfigInitializer, MockSparseEmbedding, MockOptimizer from tests.mx_rec.graph.mock_dataset import gen_mock_dataset @@ -70,16 +61,19 @@ def _gen_mock_get_anchor_attribute(is_training: bool = True) -> Callable: class GetPreprocessingMapFuncTest(TestCase): + def setUp(self) -> None: + self._modifier = _GraphModifier() + def tearDown(self) -> None: tf.compat.v1.reset_default_graph() def test_err_none_names_and_indexes(self): - mock_graph_def = tf.compat.v1.GraphDef() + mock_graph_def = self._modifier._full_graph.as_graph_def() mock_input_names = [] mock_output_names = [] with self.assertRaises(ValueError): - get_preprocessing_map_func(mock_graph_def, mock_input_names, mock_output_names) + _GraphModifier._get_preprocessing_map_func(mock_graph_def, mock_input_names, mock_output_names) class GetInputIndexListTest(TestCase): @@ -93,70 +87,11 @@ class GetInputIndexListTest(TestCase): mock_base_count = 0 with self.assertRaises(ValueError): - get_input_index_list( + _get_input_index_list( mock_cutting_point_list, mock_replace_ment_specs, mock_mapping_name_list, mock_base_count ) -class FindMakeIteratorOpTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok(self): - mock_dataset = gen_mock_dataset() - mock_iterator = mock_dataset.make_initializable_iterator() - mock_batch = mock_iterator.get_next() - mock_ids = mock_batch.get("mock_ids") - - found_iter_op = find_make_iterator_op(mock_ids) - self.assertEqual(found_iter_op.type, "MakeIterator") - - def test_err_no_tgt_dataset_op(self): - mock_ids = tf.zeros(shape=(4096, 8)) - with self.assertRaises(ValueError): - find_make_iterator_op(mock_ids) - - -class FindTargetDatasetOpTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok(self): - mock_dataset = gen_mock_dataset() - mock_iterator = mock_dataset.make_initializable_iterator() - mock_batch = mock_iterator.get_next() - mock_ids = mock_batch.get("mock_ids") - mock_base_op = tf.identity(mock_ids).op - - found_tgt_dataset_op = find_target_dataset_op(base_ops=mock_base_op, op_type="IteratorGetNext") - self.assertEqual(found_tgt_dataset_op, mock_ids.op) - - def 
test_err_no_tgt_op_type(self): - mock_ids = tf.zeros(shape=(4096, 8)) - mock_base_op = mock_ids.op - with self.assertRaises(ValueError): - find_target_dataset_op(mock_base_op, "IteratorGetNext") - - -class GetDatasetOpTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_ok(self): - mock_dataset = gen_mock_dataset() - mock_iterator = mock_dataset.make_initializable_iterator() - mock_batch = mock_iterator.get_next() - mock_ids = mock_batch.get("mock_ids") - mock_get_next_op = mock_ids.op - - found_dataset_op = get_dataset_op(mock_get_next_op) - self.assertEqual(found_dataset_op.type, "OptimizeDataset") - - def test_err_invalid_op_type(self): - mock_get_next_op = tf.zeros(shape=(4096, 8)).op - with self.assertRaises(TypeError): - get_dataset_op(mock_get_next_op) - class GetPassingTensorList(TestCase): def tearDown(self) -> None: @@ -176,7 +111,7 @@ class GetPassingTensorList(TestCase): "output_index_list": [0], "sub_src_tensors": mock_cutting_point_list, } - passing_tensor_list, output_index_list, sub_src_tensors = get_passing_tensor_list( + passing_tensor_list, output_index_list, sub_src_tensors = _get_passing_tensor_list( mock_cutting_point_list, mock_tgt_op ) self.assertEqual(passing_tensor_list, expected["passing_tensor_list"]) @@ -184,29 +119,23 @@ class GetPassingTensorList(TestCase): self.assertEqual(sub_src_tensors, expected["sub_src_tensors"]) -class FindTargetInstanceDatasetTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - def test_err_no_target_dataset_instance(self): - with self.assertRaises(LookupError): - find_target_instance_dataset(None) - class GetSrcDatasetTest(TestCase): + def setUp(self) -> None: + self._modifier = _GraphModifier() + def tearDown(self) -> None: tf.compat.v1.reset_default_graph() def test_ok_one_shot(self): mock_dataset = gen_mock_dataset() mock_prefetch_dataset = mock_dataset.prefetch(10) - mock_double_prefetch_dataset = mock_prefetch_dataset.prefetch(10) mock_iterator = mock_prefetch_dataset.make_one_shot_iterator() mock_batch = mock_iterator.get_next() mock_ids = mock_batch.get("mock_ids") mock_get_next_op = mock_ids.op - src_dataset = get_src_dataset(mock_get_next_op, is_training=True) + src_dataset = self._modifier._get_src_dataset(mock_get_next_op, is_training=True) self.assertEqual(src_dataset, mock_dataset) @@ -215,6 +144,9 @@ class GetSrcDatasetTest(TestCase): ConfigInitializer=Mock(return_value=MockConfigInitializer()), ) class GetTgtDatasetTest(TestCase): + def setUp(self) -> None: + self._modifier = _GraphModifier() + def tearDown(self) -> None: tf.compat.v1.reset_default_graph() @@ -233,22 +165,16 @@ class GetTgtDatasetTest(TestCase): mock_batch = mock_iterator.get_next() mock_ids = mock_batch.get("mock_ids") mock_sub_cutting_point_list = [mock_ids] - mock_records = AnchorRecord( - defaultdict(), - [], - [], - [], - tf.compat.v1.GraphDef(), - [], - [], - True - ) + mock_records = _AnchorRecord(defaultdict(), [], [], [], tf.compat.v1.GraphDef(), [], [], True) - tgt_dataset = get_tgt_dataset(mock_dataset, mock_sub_cutting_point_list, mock_records) + tgt_dataset = self._modifier._get_tgt_dataset(mock_dataset, mock_sub_cutting_point_list, mock_records) self.assertIsNotNone(tgt_dataset) class ModifyGraphForAscTest(TestCase): + def setUp(self) -> None: + self._modifier = _GraphModifier() + def tearDown(self) -> None: tf.compat.v1.reset_default_graph() @@ -257,6 +183,11 @@ class ModifyGraphForAscTest(TestCase): get_asc_insert_func=Mock(return_value=lambda x, y: x), ) 
@patch.multiple("mx_rec.graph.modifier.BaseSparseEmbedding", get_anchor_attribute=_gen_mock_get_anchor_attribute()) + @patch.multiple( + "mx_rec.core.asc.manager", + should_skip=MagicMock(return_value=True), + check_dangling_table=MagicMock(return_value=["test_table"]), + ) @patch("mx_rec.graph.modifier.ConfigInitializer") def test_ok_train_mode(self, modifier_config_initializer): mock_config_initializer = MockConfigInitializer(modify_graph=True, merged_multi_lookup=True) @@ -268,9 +199,16 @@ class ModifyGraphForAscTest(TestCase): mock_ids = mock_batch.get("mock_ids") mock_cutting_point = tf.identity(mock_ids) + test_table = MockSparseEmbedding("test_table") + test_table.is_hbm = True + mock_config_initializer.get_instance().sparse_embed_config.table_instance_dict = dict(test_table=test_table) + + mock_opt = MockOptimizer() + modifier_config_initializer.get_instance().optimizer_config.optimizer_instance = mock_opt + tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE, mock_cutting_point) - modify_graph_for_asc() + self._modifier.modify_graph_for_asc() @patch.multiple( "mx_rec.graph.modifier", @@ -279,12 +217,13 @@ class ModifyGraphForAscTest(TestCase): ) @patch.multiple( "mx_rec.graph.modifier.BaseSparseEmbedding", - get_anchor_attribute=_gen_mock_get_anchor_attribute(is_training=False) + get_anchor_attribute=_gen_mock_get_anchor_attribute(is_training=False), ) @patch("mx_rec.graph.modifier.ConfigInitializer") def test_ok_eval_mode(self, modifier_config_initializer): - mock_config_initializer = MockConfigInitializer(modify_graph=True, merged_multi_lookup=True, - bool_gauge_set={"evaluate"}) + mock_config_initializer = MockConfigInitializer( + modify_graph=True, merged_multi_lookup=True, bool_gauge_set={"evaluate"} + ) modifier_config_initializer.get_instance = Mock(return_value=mock_config_initializer) mock_dataset = gen_mock_dataset() @@ -293,9 +232,16 @@ class ModifyGraphForAscTest(TestCase): mock_ids = mock_batch.get("mock_ids") mock_cutting_point = tf.identity(mock_ids) + test_table = MockSparseEmbedding("test_table") + test_table.is_hbm = True + mock_config_initializer.get_instance().sparse_embed_config.table_instance_dict = dict(test_table=test_table) + + mock_opt = MockOptimizer() + modifier_config_initializer.get_instance().optimizer_config.optimizer_instance = mock_opt + tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE, mock_cutting_point) - modify_graph_for_asc() + self._modifier.modify_graph_for_asc() @patch.multiple( "mx_rec.graph.modifier", @@ -316,10 +262,13 @@ class ModifyGraphForAscTest(TestCase): tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE, mock_cutting_point) with self.assertRaises(RuntimeError): - modify_graph_for_asc() + self._modifier.modify_graph_for_asc() class GetTimestampIndexTest(TestCase): + def setUp(self) -> None: + self._graph = tf.compat.v1.get_default_graph() + def tearDown(self) -> None: tf.compat.v1.reset_default_graph() @@ -341,7 +290,7 @@ class GetTimestampIndexTest(TestCase): tf.compat.v1.add_to_collection(ASCEND_TIMESTAMP, mock_timestamp) - timestamp_index = get_timestamp_index(mock_get_next_op, is_training=True) + timestamp_index = _get_timestamp_index(self._graph, mock_get_next_op, is_training=True) self.assertEqual(timestamp_index, 2) @@ -365,8 +314,9 @@ class GraphModifierHookTest(TestCase): ) @patch("mx_rec.graph.modifier.ConfigInitializer") def test_ok(self, modifier_config_initializer): - mock_config_initializer = MockConfigInitializer(modify_graph=True, is_graph_modify_hook_running=True, - 
iterator_type="MakeIterator") + mock_config_initializer = MockConfigInitializer( + modify_graph=True, is_graph_modify_hook_running=True, iterator_type="MakeIterator" + ) modifier_config_initializer.get_instance = Mock(return_value=mock_config_initializer) mock_dataset = gen_mock_dataset() @@ -389,8 +339,9 @@ class GraphModifierHookTest(TestCase): ) @patch("mx_rec.graph.modifier.ConfigInitializer") def test_err_invalid_iterator_type(self, modifier_config_initializer): - mock_config_initializer = MockConfigInitializer(modify_graph=True, is_graph_modify_hook_running=True, - iterator_type="InvalidIterator") + mock_config_initializer = MockConfigInitializer( + modify_graph=True, is_graph_modify_hook_running=True, iterator_type="InvalidIterator" + ) modifier_config_initializer.get_instance = Mock(return_value=mock_config_initializer) mock_dataset = gen_mock_dataset() diff --git a/tests/mx_rec/graph/test_slicers.py b/tests/mx_rec/graph/test_slicers.py new file mode 100644 index 0000000000000000000000000000000000000000..b6d9cad978d4e8f95e4010974a005b3fd5ecf1aa --- /dev/null +++ b/tests/mx_rec/graph/test_slicers.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python3 +# coding: UTF-8 +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import unittest +from unittest.mock import patch, Mock + +import tensorflow as tf +from tensorflow import Graph + +from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE +from mx_rec.graph.constants import AnchorDatasetOp +from mx_rec.graph.slicers import NoGradSubgraphSlicer, LookupSubgraphSlicer, OrphanLookupKeySlicer +from tests.mx_rec.graph.mock_dataset import gen_mock_dataset + + +class MockNoGradSubgraphSlicer(NoGradSubgraphSlicer): + def __init__(self, full_graph: Graph = None, info_dir: str = "slicing") -> None: + super().__init__(full_graph, info_dir) + + def summarize(self) -> None: + pass + + def slice(self) -> None: + pass + + +class NoGradSubgraphSlicerTestCase(unittest.TestCase): + def test_ok_slice_ops(self): + with tf.compat.v1.Graph().as_default(): + dataset = gen_mock_dataset() + prefetch_dataset = dataset.prefetch(0) + + iterator = tf.compat.v1.data.make_initializable_iterator(prefetch_dataset) + batch = iterator.get_next() + + mock_ids = batch["mock_ids"] + mock_labels = batch["mock_labels"] + + inner_tensor = tf.identity(mock_ids) + inner_op = inner_tensor.op + + tf.identity(inner_tensor) + tf.identity(mock_labels) + + sliced_ops = {inner_op} + MockNoGradSubgraphSlicer()._slice_ops(sliced_ops, is_training=True) + + g = tf.compat.v1.get_default_graph() + prefetch_datasets = [op for op in g.get_operations() if AnchorDatasetOp.PREFETCH_DATASET.value in op.name] + self.assertEqual(len(prefetch_datasets), 2) + + def test_ok_find_min_dep_ops(self): + with tf.compat.v1.Graph().as_default(): + dataset = gen_mock_dataset() + iterator = dataset.make_initializable_iterator() + batch = iterator.get_next() + ids = batch["mock_ids"] + + subgraph_in = tf.identity(ids) + subgraph_out = tf.identity(subgraph_in) + base_ops = {subgraph_out.op} + + min_dep_ops = NoGradSubgraphSlicer._find_min_dep_ops(base_ops) + self.assertEqual(min_dep_ops, {subgraph_in.op, subgraph_out.op}) + + def test_ok_validate_op(self): + with tf.compat.v1.Graph().as_default(): + t = tf.constant(0) + t = tf.add(t, 1) + t = tf.subtract(t, 1) + op = t.op + + is_valid = NoGradSubgraphSlicer._validate_op(op) + self.assertTrue(is_valid, True) + + def test_ok_find_subgraph_in_and_out(self): + with tf.compat.v1.Graph().as_default(): + dataset = gen_mock_dataset() + iterator = dataset.make_initializable_iterator() + batch = iterator.get_next() + ids = batch.get("mock_ids") + + input_tensor = tf.identity(ids) + inner_tensor = tf.identity(input_tensor) + output_tensor = tf.identity(inner_tensor) + subgraph_ops = {inner_tensor.op} + + (subgraph_in, subgraph_out) = MockNoGradSubgraphSlicer()._find_subgraph_in_and_out(subgraph_ops) + self.assertEqual(subgraph_in, {input_tensor.op: {inner_tensor.op}}) + self.assertEqual(subgraph_out, {output_tensor.op: {inner_tensor.op}}) + + def test_ok_find_old_dataset(self): + with tf.compat.v1.Graph().as_default(): + dataset = gen_mock_dataset() + iterator = tf.compat.v1.data.make_initializable_iterator(dataset) + batch = iterator.get_next() + ids = batch["mock_ids"] + get_next = ids.op + + old_dataset = MockNoGradSubgraphSlicer()._find_old_dataset(get_next, is_training=True) + self.assertEqual(old_dataset, dataset) + + with tf.compat.v1.Graph().as_default(): + dataset = gen_mock_dataset() + prefetch_dataset = dataset.prefetch(0) + iterator = tf.compat.v1.data.make_one_shot_iterator(prefetch_dataset) + batch = iterator.get_next() + ids = batch["mock_ids"] + get_next = ids.op + + old_dataset = 
MockNoGradSubgraphSlicer()._find_old_dataset(get_next, is_training=True) + self.assertEqual(old_dataset, dataset) + + with tf.compat.v1.Graph().as_default(): + dataset = gen_mock_dataset() + prefetch_dataset = dataset.prefetch(0) + gen_mock_dataset().prefetch(0) + + iterator = tf.compat.v1.data.make_one_shot_iterator(dataset) + batch = iterator.get_next() + ids = batch["mock_ids"] + get_next = ids.op + + old_dataset = MockNoGradSubgraphSlicer()._find_old_dataset(get_next, is_training=True) + self.assertEqual(old_dataset, dataset) + + with tf.compat.v1.Graph().as_default(): + dataset = gen_mock_dataset() + prefetch_dataset = dataset.prefetch(0) + gen_mock_dataset().prefetch(0) + + iterator = tf.compat.v1.data.make_one_shot_iterator(dataset) + batch = iterator.get_next() + ids = batch["mock_ids"] + get_next = ids.op + + old_dataset = MockNoGradSubgraphSlicer()._find_old_dataset(get_next, is_training=False) + self.assertEqual(old_dataset, dataset) + + def test_ok_make_new_dataset(self): + with tf.compat.v1.Graph().as_default(): + dataset = gen_mock_dataset() + prefetch_dataset = dataset.prefetch(0) + iterator = tf.compat.v1.data.make_initializable_iterator(prefetch_dataset) + batch = iterator.get_next() + ids = batch["mock_ids"] + + in_op = ids.op + inner_tensor = tf.identity(ids) + inner_op = inner_tensor.op + out_op = tf.identity(inner_tensor).op + + sliced_ops = {inner_op} + in_op_to_edge_ops = {in_op: {inner_op}} + out_op_to_edge_ops = {out_op: {inner_op}} + + new_dataset = MockNoGradSubgraphSlicer()._make_new_dataset( + dataset, sliced_ops, in_op_to_edge_ops, out_op_to_edge_ops + ) + new_prefetch_dataset = new_dataset + new_iter = tf.compat.v1.data.make_initializable_iterator(new_prefetch_dataset) + new_batch = new_iter.get_next() + self.assertEqual(len(new_batch), 4) + + def test_ok_topo_sort_sliced_ops(self): + with tf.compat.v1.Graph().as_default(): + t1 = tf.constant(0) + t2 = tf.identity(t1) + t3 = tf.identity(t2) + ops = {t3.op, t2.op, t1.op} + + topo_sorted_ops = NoGradSubgraphSlicer._topo_sort_sliced_ops(ops) + self.assertEqual(topo_sorted_ops, [t1.op, t2.op, t3.op]) + + def test_ok_clone_subgraph_into_funcgraph(self): + with tf.compat.v1.Graph().as_default(): + prefetch_dataset = gen_mock_dataset().prefetch(0) + iterator = tf.compat.v1.data.make_initializable_iterator(prefetch_dataset) + batch = iterator.get_next() + ids = batch["mock_ids"] + + in_op = ids.op + inner_tensor = tf.identity(ids) + inner_op = inner_tensor.op + out_op = tf.identity(inner_tensor).op + + sliced_ops = {inner_op} + in_op_to_edge_ops = {in_op: {inner_op}} + out_op_to_edge_ops = {out_op: {inner_op}} + + with patch.object(tf.compat.v1.Graph, "get_tensor_by_name", return_value=tf.identity(inner_tensor)): + new_batch = MockNoGradSubgraphSlicer()._clone_subgraph_into_funcgraph( + sliced_ops, in_op_to_edge_ops, out_op_to_edge_ops, batch + ) + self.assertEqual(len(new_batch), 4) + + def test_ok_make_new_get_next(self): + with tf.compat.v1.Graph().as_default(): + prefetch_dataset = gen_mock_dataset().prefetch(0) + iterator = tf.compat.v1.data.make_initializable_iterator(prefetch_dataset) + batch = iterator.get_next() + ids = batch["mock_ids"] + + old_get_next = ids.op + new_dataset = gen_mock_dataset().prefetch(0) + + new_get_next = MockNoGradSubgraphSlicer()._make_new_get_next(old_get_next, new_dataset) + self.assertIsNotNone(new_get_next) + + with tf.compat.v1.Graph().as_default(): + prefetch_dataset = gen_mock_dataset().prefetch(0) + iterator = tf.compat.v1.data.make_one_shot_iterator(prefetch_dataset) + batch 
= iterator.get_next() + ids = batch["mock_ids"] + + old_get_next = ids.op + new_dataset = gen_mock_dataset().prefetch(0) + + new_get_next = MockNoGradSubgraphSlicer()._make_new_get_next(old_get_next, new_dataset) + self.assertIsNotNone(new_get_next) + + +class LookupSubGraphSlicerTestCase(unittest.TestCase): + def test_ok_find_all_tgt_ops(self): + with tf.compat.v1.Graph().as_default(): + prefetch_dataset = gen_mock_dataset().prefetch(0) + iterator = tf.compat.v1.data.make_initializable_iterator(prefetch_dataset) + batch = iterator.get_next() + ids = batch["mock_ids"] + + inner_tensor = tf.identity(ids) + tf.identity(inner_tensor) + + all_tgt_ops = LookupSubgraphSlicer(op_types=["Identity"])._find_all_tgt_ops() + self.assertEqual(len(all_tgt_ops), 2) + + @patch.multiple( + "mx_rec.core.emb.base_sparse_embedding.BaseSparseEmbedding", get_anchor_attribute=Mock(return_value=True) + ) + def test_ok_find_sliceable_tgt_ops(self): + with tf.compat.v1.Graph().as_default(): + prefetch_dataset = gen_mock_dataset().prefetch(0) + iterator = tf.compat.v1.data.make_initializable_iterator(prefetch_dataset) + batch = iterator.get_next() + ids = batch["mock_ids"] + + inner_tensor = tf.identity(ids) + lookup_key = tf.identity(inner_tensor) + tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE, lookup_key) + + all_tgt_ops = LookupSubgraphSlicer(op_types=["Identity"])._find_sliceable_tgt_ops() + self.assertEqual(len(all_tgt_ops), 2) + + +class OrphanLookupKeySlicerTestCase(unittest.TestCase): + @patch.multiple("mx_rec.graph.slicers.utils", export_pb_graph=Mock(return_value=None)) + def test_ok_slice_ops(self): + with tf.compat.v1.Graph().as_default(): + prefetch_dataset = gen_mock_dataset().prefetch(0) + iterator = tf.compat.v1.data.make_initializable_iterator(prefetch_dataset) + batch = iterator.get_next() + ids = batch["mock_ids"] + + inner_tensor = tf.constant(0, dtype=ids.dtype, shape=ids.shape) + lookup_key = tf.identity(inner_tensor) + tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE, lookup_key) + + sliceable_ops = {inner_tensor.op} + OrphanLookupKeySlicer()._slice_ops(sliceable_ops, is_training=False) + + g = tf.compat.v1.get_default_graph() + prefetch_datasets = [op for op in g.get_operations() if AnchorDatasetOp.PREFETCH_DATASET.value in op.name] + self.assertEqual(len(prefetch_datasets), 2) + + @patch.multiple( + "mx_rec.core.emb.base_sparse_embedding.BaseSparseEmbedding", get_anchor_attribute=Mock(return_value=True) + ) + def test_ok_find_sliceable_tgt_ops(self): + with tf.compat.v1.Graph().as_default(): + prefetch_dataset = gen_mock_dataset().prefetch(0) + iterator = tf.compat.v1.data.make_initializable_iterator(prefetch_dataset) + batch = iterator.get_next() + ids = batch["mock_ids"] + + inner_tensor = tf.constant(0, dtype=ids.dtype, shape=ids.shape) + lookup_key = tf.identity(inner_tensor) + tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE, lookup_key) + + all_tgt_ops = OrphanLookupKeySlicer()._find_sliceable_tgt_ops() + self.assertEqual(len(all_tgt_ops), 2) diff --git a/tests/mx_rec/graph/test_utils.py b/tests/mx_rec/graph/test_utils.py index 5a4efffc3d46fb08d1ba6a7fcb91c411d4f9e772..7aead90e792308dede15363f9798b042d426e227 100644 --- a/tests/mx_rec/graph/test_utils.py +++ b/tests/mx_rec/graph/test_utils.py @@ -15,7 +15,6 @@ # limitations under the License. 
# ============================================================================== -import sys import os import pathlib import shutil @@ -24,42 +23,45 @@ from unittest import TestCase import tensorflow as tf from tensorflow import Tensor, TensorSpec + from mx_rec.constants.constants import ASCAnchorAttr from mx_rec.core.emb.base_sparse_embedding import BaseSparseEmbedding from mx_rec.graph.utils import ( - check_input_list, + find_trans_dataset, find_parent_op, + find_make_iterator_op, + find_target_instance_dataset, + upward_bfs_op, + check_and_force_list, check_cutting_points, export_pb_graph, make_sorted_key_to_tensor_list, replace_anchor_vec, ) +from tests.mx_rec.graph.mock_dataset import gen_mock_dataset -class CheckInputListTest(TestCase): - def tearDown(self): - tf.compat.v1.reset_default_graph() - - def test_ok_single_object(self): - mock_obj = "obj" - obj_type = str - - checked_objs = check_input_list(mock_obj, obj_type) - self.assertEqual([mock_obj], checked_objs) +class FindTransDatasetTest(TestCase): + def setUp(self) -> None: + self._graph = tf.compat.v1.get_default_graph() - def test_ok_object_list(self): - mock_objs = ["obj1", "obj2", "ojb3"] - obj_type = str + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() - checked_cutting_points = check_input_list(mock_objs, obj_type) - self.assertEqual(mock_objs, checked_cutting_points) + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_get_next_op = mock_ids.op - def test_err_inconsistent_object_and_type(self): - mock_objs = ["obj1", "obj2", "ojb3"] - obj_type = Tensor + found_dataset_op = find_trans_dataset(self._graph, mock_get_next_op) + self.assertEqual(found_dataset_op.type, "OptimizeDataset") - with self.assertRaises(ValueError): - check_input_list(mock_objs, obj_type) + def test_err_invalid_op_type(self): + mock_get_next_op = tf.zeros(shape=(4096, 8)).op + with self.assertRaises(TypeError): + find_trans_dataset(self._graph, mock_get_next_op) class FindParentOpTest(TestCase): @@ -76,6 +78,64 @@ class FindParentOpTest(TestCase): self.assertEqual([mock_parent_op], parent_op) +class FindMakeIteratorOpTest(TestCase): + def setUp(self) -> None: + self._graph = tf.compat.v1.get_default_graph() + + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + + found_iter_op = find_make_iterator_op(self._graph, mock_ids) + self.assertEqual(found_iter_op.type, "MakeIterator") + + def test_err_no_tgt_dataset_op(self): + mock_ids = tf.zeros(shape=(4096, 8)) + with self.assertRaises(ValueError): + find_make_iterator_op(self._graph, mock_ids) + + +class FindTargetInstanceDatasetTest(TestCase): + def setUp(self) -> None: + self._graph = tf.compat.v1.get_default_graph() + + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_err_no_target_dataset_instance(self): + with self.assertRaises(LookupError): + find_target_instance_dataset(self._graph, None) + + +class UpwardBFSOpTest(TestCase): + def setUp(self) -> None: + self._graph = tf.compat.v1.get_default_graph() + + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + 
mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_base_op = tf.identity(mock_ids).op + + found_tgt_dataset_op = upward_bfs_op(base_ops=mock_base_op, tgt_op_type="IteratorGetNext") + self.assertEqual(found_tgt_dataset_op, mock_ids.op) + + def test_err_no_tgt_op_type(self): + mock_ids = tf.zeros(shape=(4096, 8)) + mock_base_op = mock_ids.op + with self.assertRaises(ValueError): + upward_bfs_op(base_ops=mock_base_op, tgt_op_type="IteratorGetNext") + + class CheckCuttingPointsTest(TestCase): def setUp(self): self._generator_iter_times = 3 @@ -98,6 +158,32 @@ class CheckCuttingPointsTest(TestCase): check_cutting_points(mock_cutting_point_list) +class CheckAndForceListTest(TestCase): + def tearDown(self): + tf.compat.v1.reset_default_graph() + + def test_ok_single_object(self): + mock_obj = "obj" + obj_type = str + + checked_objs = check_and_force_list(mock_obj, obj_type) + self.assertEqual([mock_obj], checked_objs) + + def test_ok_object_list(self): + mock_objs = ["obj1", "obj2", "ojb3"] + obj_type = str + + checked_cutting_points = check_and_force_list(mock_objs, obj_type) + self.assertEqual(mock_objs, checked_cutting_points) + + def test_err_inconsistent_object_and_type(self): + mock_objs = ["obj1", "obj2", "ojb3"] + obj_type = Tensor + + with self.assertRaises(ValueError): + check_and_force_list(mock_objs, obj_type) + + class ExportPBGraphTest(TestCase): def setUp(self) -> None: self._dir_name = "./export_graph" @@ -162,7 +248,7 @@ class ReplaceAnchorVecTest(TestCase): anchor_vec_output = tf.identity(anchor_vec, name="anchor_vec_output") BaseSparseEmbedding.anchor_tensor_specs[mock_cutting_point][mock_attribute] = anchor_vec - replace_anchor_vec(mock_cutting_point, mock_attribute, mock_anchor) + replace_anchor_vec(tf.compat.v1.get_default_graph(), mock_cutting_point, mock_attribute, mock_anchor) self.assertEqual(anchor_vec_output.op.inputs[0], mock_anchor) diff --git a/tests/mx_rec/saver/sparse_embedding_mock.py b/tests/mx_rec/saver/sparse_embedding_mock.py index 03df14664f197fc89b6b24eb7cfac02788ba6bef..7f7d437d74991711f34db372677cef23062af8c0 100644 --- a/tests/mx_rec/saver/sparse_embedding_mock.py +++ b/tests/mx_rec/saver/sparse_embedding_mock.py @@ -29,11 +29,4 @@ class SparseEmbeddingMock: self.emb_size = 4 self.is_hbm = host_vocab_size == 0 self.host_vocabulary_size = host_vocab_size - self.optimizer = dict() self.use_dynamic_expansion = False - - def set_optimizer(self, key, state_dict): - if key in self.optimizer: - raise ValueError(f"Optimizer {key} has been set for hash table {self.table_name}") - - self.optimizer[key] = state_dict diff --git a/tests/mx_rec/saver/test_saver.py b/tests/mx_rec/saver/test_saver.py index 60c40a2153bccc306bc7c26e0983d2b5ef7a5f94..53066038a17e04c79964b93c492fac447e0f6cff 100644 --- a/tests/mx_rec/saver/test_saver.py +++ b/tests/mx_rec/saver/test_saver.py @@ -23,6 +23,7 @@ import tensorflow as tf from mx_rec.saver.saver import Saver from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION +from mx_rec.util.initialize import ConfigInitializer from tests.mx_rec.core.mock_class import MockConfigInitializer from tests.mx_rec.saver.sparse_embedding_mock import SparseEmbeddingMock @@ -40,8 +41,8 @@ class TestSaver(unittest.TestCase): @mock.patch.multiple("mx_rec.saver.saver", get_rank_id=mock.MagicMock(return_value=0), - get_local_rank_size=mock.MagicMock(return_value=1), - set_optimizer_info=mock.MagicMock(return_value=None)) + get_rank_size=mock.MagicMock(return_value=1), + 
get_local_rank_size=mock.MagicMock(return_value=1)) @mock.patch("mx_rec.saver.saver.ConfigInitializer") def test_save_and_load_is_consistent(self, saver_config_initializer): mock_config_initializer = \ @@ -60,18 +61,18 @@ class TestSaver(unittest.TestCase): self.saver = Saver() with tf.compat.v1.Session(graph=self.graph) as sess: - embedding_directory = "./sparse-model/test_table/embedding" + embedding_directory = "./sparse-model-1/test_table/embedding" data_file = os.path.join(embedding_directory, "slice.data") attribute_file = os.path.join(embedding_directory, "slice.attribute") sess.run(tf.global_variables_initializer()) origin_embedding = sess.run(self.var)[[0, 1, 4, 6, 8], :] - self.saver.save(sess) + self.saver.save(sess, save_path="model-1") self.assertTrue(os.path.exists(embedding_directory), "embedding目录已创建") self.assertTrue(os.path.exists(data_file), "embedding的data文件存储成功") self.assertTrue(os.path.exists(attribute_file), "embedding的attribute文件存储成功") - tf.io.gfile.rmtree("./sparse-model") + tf.io.gfile.rmtree("./sparse-model-1") def build_graph(self): self.graph = tf.compat.v1.Graph() @@ -86,7 +87,6 @@ class TestSaver(unittest.TestCase): optim_v_tensor = emb_initializer(self.shape) self.optimizer_v = tf.compat.v1.get_variable(self.optim_v_name, trainable=False, initializer=optim_v_tensor) - table_instance.set_optimizer("LazyAdam", {"momentum": self.optimizer_m, "velocity": self.optimizer_v}) tf.compat.v1.add_to_collection(ASCEND_GLOBAL_HASHTABLE_COLLECTION, self.var) return self.graph diff --git a/tests/mx_rec/util/communication/test_hccl_mgmt.py b/tests/mx_rec/util/communication/test_hccl_mgmt.py index f02570227ea36958d1334f49984c2cc2b6587cf2..870f8a3aedf7b84dd022fc06788f681e02178da2 100644 --- a/tests/mx_rec/util/communication/test_hccl_mgmt.py +++ b/tests/mx_rec/util/communication/test_hccl_mgmt.py @@ -104,16 +104,6 @@ class HCCLMGMTTest(unittest.TestCase): with self.assertRaises(ValueError): rank_to_device_dict, local_rank_size = parse_hccl_json() - def test_get_device_list(self): - device_list = get_device_list("0-7") - self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7], device_list) - device_list = get_device_list("0-3, 8-11") - self.assertEqual([0, 1, 2, 3, 8, 9, 10, 11], device_list) - with self.assertRaises(ValueError): - device_list = get_device_list("7-5, 9, 10") - with self.assertRaises(ValueError): - device_list = get_device_list("17") - if __name__ == '__main__': unittest.main() diff --git a/tests/mx_rec/util/test_variable.py b/tests/mx_rec/util/test_variable.py index c72ed9dcf284f6dd59dfa1bd8b78196bf9782f63..a3370e84b9fd309909b1e480fb4bcaaadd51a8f5 100644 --- a/tests/mx_rec/util/test_variable.py +++ b/tests/mx_rec/util/test_variable.py @@ -21,7 +21,7 @@ from unittest.mock import patch import tensorflow as tf from mx_rec.util.global_env_conf import global_env -from mx_rec.util.variable import check_and_get_config_via_var +from mx_rec.util.variable import get_config_via_var from mx_rec.util.variable import get_dense_and_sparse_variable from tests.mx_rec.core.mock_class import MockConfigInitializer @@ -29,7 +29,6 @@ from tests.mx_rec.core.mock_class import MockConfigInitializer class MockTableInstance: def __init__(self): self.is_hbm = False - self.optimizer = False @patch.multiple( @@ -44,10 +43,8 @@ class VariableTest(unittest.TestCase): """ self.cm_worker_size = global_env.cm_worker_size self.cm_chief_device = global_env.cm_chief_device - self.ascend_visible_devices = global_env.ascend_visible_devices global_env.cm_worker_size = "8" global_env.cm_chief_device = "0" - 
global_env.ascend_visible_devices = "0-7" def tearDown(self): """ @@ -56,7 +53,6 @@ class VariableTest(unittest.TestCase): """ global_env.cm_worker_size = self.cm_worker_size global_env.cm_chief_device = self.cm_chief_device - global_env.ascend_visible_devices = self.ascend_visible_devices @mock.patch("mx_rec.util.variable.ConfigInitializer") def test_get_dense_and_sparse_variable(self, variable_config_initializer): @@ -76,14 +72,6 @@ class VariableTest(unittest.TestCase): self.assertTrue(result_run) tf.reset_default_graph() - @mock.patch("mx_rec.util.variable.ConfigInitializer") - def test_check_and_get_config_via_var_when_environment_error(self, variable_config_initializer): - mock_config_initializer = MockConfigInitializer(var=MockTableInstance()) - variable_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) - - with self.assertRaises(EnvironmentError): - self.assertEqual(MockTableInstance(), check_and_get_config_via_var("1", "optimize")) - if __name__ == '__main__': unittest.main() diff --git a/tests/run_python_dt.sh b/tests/run_python_dt.sh old mode 100644 new mode 100755 index f29bf7b5f001f0149de7ddd97a645a1cdfb4a73e..475fd788db119206f49ca5771b6922ef35f01795 --- a/tests/run_python_dt.sh +++ b/tests/run_python_dt.sh @@ -26,7 +26,7 @@ if [ $ARCH == "aarch64" ]; then fi # build mxRec and get output directory -bash "$TOP_PATH"/build/build_tf1_with_opensource.sh +bash "$TOP_PATH"/build/build_tf1.sh # create libasc directory and copy so files into it cd "$TOP_PATH"/mx_rec @@ -36,7 +36,7 @@ cd - # set environment variable export PYTHONPATH="${TOP_PATH}"/output:$PYTHONPATH -export LD_LIBRARY_PATH="${TOP_PATH}"/output:/usr/local/lib:$LD_LIBRARY_PATH +export LD_LIBRARY_PATH="${TOP_PATH}"/output:/usr/local/lib:"${TOP_PATH}"/mx_rec/libasc:$LD_LIBRARY_PATH rm -rf result mkdir -p result diff --git a/tools/atomic/sparse_lookup.py b/tools/atomic/sparse_lookup.py index 570c683e5934daa1d688708abad27639f347f73f..73ff7f330252f71f92acf4305b283b5d23bc8e3d 100644 --- a/tools/atomic/sparse_lookup.py +++ b/tools/atomic/sparse_lookup.py @@ -28,7 +28,6 @@ from sparse_ops.config import set_ascend_env USE_PIPELINE_TEST = False USE_STATIC = False -USE_HOT = False USE_EXPANSION = False from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET @@ -171,7 +170,7 @@ if __name__ == '__main__': host_vocab_size = 0 init(True, rank_id=rank_id, rank_size=local_rank_size, train_interval=100, eval_steps=-1, - prefetch_batch_number=1, use_dynamic=0, use_hot=1, use_dynamic_expansion=0) + prefetch_batch_number=1, use_dynamic=0, use_dynamic_expansion=0) tf.disable_eager_execution() ###################################### diff --git a/tools/atomic/sparse_lookup_with_grad.py b/tools/atomic/sparse_lookup_with_grad.py index 3d7d37e5bd89764df52b00bbbfcd8b254b094938..ea80bce389e63c0be84f8c6a6acd4dbb26766046 100644 --- a/tools/atomic/sparse_lookup_with_grad.py +++ b/tools/atomic/sparse_lookup_with_grad.py @@ -28,7 +28,6 @@ from sparse_ops.config import set_ascend_env USE_PIPELINE_TEST = False USE_STATIC = False -USE_HOT = False USE_EXPANSION = False @@ -173,7 +172,7 @@ if __name__ == '__main__': host_vocab_size = 0 init(True, rank_id=rank_id, rank_size=local_rank_size, train_interval=100, eval_steps=-1, - prefetch_batch_number=1, use_dynamic=0, use_hot=1, use_dynamic_expansion=0) + prefetch_batch_number=1, use_dynamic=0, use_dynamic_expansion=0) tf.disable_eager_execution() ###################################### @@ -204,7 +203,6 @@ if __name__ == '__main__': 
emb_initializer=tf.variance_scaling_initializer(mode="fan_avg",
                                                                          distribution='normal',
                                                                          seed=0),
                          device_vocabulary_size=dev_vocab_size * local_rank_size,
-                         optimizer_list=sparse_optimizer_list,
                          mode=MxRecMode.mapping("ASC"))
 
     sparse_variables = tf.compat.v1.get_collection(get_ascend_global_hashtable_collection())
diff --git a/tools/graph_partition/gen_config.py b/tools/graph_partition/gen_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cd69de36356896e9ffa8bc746fb4344fa8fa63c
--- /dev/null
+++ b/tools/graph_partition/gen_config.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import argparse
+import os
+
+import tensorflow as tf
+from graph_partition import GraphPartitioner
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Generate a serving config with the in_out_pair for graph partition.")
+    parser.add_argument("--model_path", type=str, default="./")
+    parser.add_argument("--output_path", type=str, default="./")
+    parser.add_argument("--output_filename", type=str, default="config.cfg")
+    args = parser.parse_args()
+
+    signature_def = "serving_default"
+
+    # Model configuration
+    embedding_lookup_op_type = ["Sum"]
+    heavy_load_ops = ["MatMul"]  # operators that must be sunk to the device (currently unused)
+    use_whole_graph = False
+    partition_to_first_heavy_load = False
+    #########################################################
+
+    output_filepath = os.path.join(args.output_path, args.output_filename)
+
+    with tf.compat.v1.Session() as sess:
+        meta_graph = tf.compat.v1.saved_model.loader.load(
+            sess, ["serve"], args.model_path
+        )
+        ops = sess.graph.get_operations()
+        graph_partitioner = GraphPartitioner()
+
+        graph_partitioner.graph = sess.graph
+        graph_partitioner.signature_def = meta_graph.signature_def.get(signature_def)
+        graph_partitioner.set_embedding_lookup_op_type(embedding_lookup_op_type)
+
+        inputs, outputs = graph_partitioner.get_sub_graph()
+
+        res_string = "[[" + inputs + "," + outputs + "]]"
+
+        # read the config template and fill in the in_out_pair placeholder
+        ori_test = open("template.cfg")
+        template = ori_test.read()
+        output = template.replace("#value@in_out_pair#", res_string)
+        if os.path.exists(output_filepath):
+            os.remove(output_filepath)
+
+        # open the output file; os.open is used so the permission bits can be set explicitly
+        text_file = os.fdopen(os.open(output_filepath, os.O_WRONLY | os.O_CREAT, 0o666), "w")
+
+        # write string to file
+        text_file.write(output)
+
+        # close file
+        text_file.close()
+        ori_test.close()
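+
+# Example invocation (illustrative; paths are placeholders):
+#     python3 gen_config.py --model_path ./saved_model --output_path ./
+# The file written to --output_filename is template.cfg with the
+# "#value@in_out_pair#" marker replaced by the "[[inputs, outputs]]"
+# pair computed above.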
diff --git a/tools/graph_partition/graph_partition.py b/tools/graph_partition/graph_partition.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ebfbfba07f59ff769fa7869178720a9bde4876a
--- /dev/null
+++ b/tools/graph_partition/graph_partition.py
@@ -0,0 +1,133 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import tensorflow as tf
+from tensorflow.contrib import graph_editor as ge
+
+
+class GraphPartitioner:
+    def __init__(self):
+        self.signature_def = None
+        self.graph = None
+        self.op_node_lookup = dict()
+        self.input_op_nodes = []
+        self.output_op_nodes = []
+        self.tensor_node_lookup = dict()
+        self.heavy_load_ops = []
+        self.embedding_lookup_op_type = None
+        self.first_heavy_load_on_sparse_path = set()
+        self.first_op_after_lookup = []
+        self.seen = set()
+        self.post_out = set()
+        self.partition_to_first_heavy_load = False
+
+        self.sparse_lookup_ops = []
+        self.sparse_lookup_tensors = []
+        self.input_nodes = []
+        self.output_nodes = []
+
+    @staticmethod
+    def has_gray_downstreams(op):
+        # True if any op downstream of `op` is on the gray list
+        gray_list = ["DynamicPartition"]
+        down_ops = ge.get_forward_walk_ops([op])
+        for down_op in down_ops:
+            if down_op.type in gray_list:
+                return True
+        return False
+
+    def set_embedding_lookup_op_type(self, s):
+        self.embedding_lookup_op_type = s
+
+    def get_sub_graph(self):
+        # collect the embedding-lookup ops; if there are none, fall back to
+        # the graph's top-level ops (ops that consume no other op's output)
+        for op in self.graph.get_operations():
+            if self._is_embedding_lookup(op):
+                self.sparse_lookup_ops.append(op)
+        if not self.sparse_lookup_ops:
+            for op in self.graph.get_operations():
+                is_top_op = True
+                for op1 in self.graph.get_operations():
+                    for tensor in op1.outputs:
+                        if tensor in op.inputs:
+                            is_top_op = False
+                            break
+                    if not is_top_op:
+                        break
+                if is_top_op:
+                    self.sparse_lookup_ops.append(op)
+        # drop candidates that feed gray-listed downstream consumers
+        check_ops = self.sparse_lookup_ops
+        self.sparse_lookup_ops = []
+        for op in check_ops:
+            if not self.has_gray_downstreams(op):
+                self.sparse_lookup_ops.append(op)
+                self.sparse_lookup_tensors.extend(op.outputs)
+
+        # input nodes are the direct consumers of the lookup outputs
+        for op in self.graph.get_operations():
+            for tensor in self.sparse_lookup_tensors:
+                if tensor in op.inputs:
+                    self.input_nodes.append(op)
+        # output nodes are resolved from the signature_def outputs
+        for k, v in self.signature_def.outputs.items():
+            op_name = (
+                str(v)
+                .split("\n")[0]
+                .replace(" ", "")
+                .replace('"', "")
+                .split(":")[1]
+                .split(":")[0]
+            )
+            for op in self.graph.get_operations():
+                if op.name == op_name:
+                    self.output_nodes.append(op)
+
+        float_ups = []
+        to_expand = []
+        in_str = []
+
+        # walk downward from the input nodes, recording op names
+        for op in self.input_nodes:
+            if op.type not in float_ups:
+                if op.name not in in_str:
+                    in_str.append(op.name)
+            else:
+                to_expand.append(op)
+
+        while to_expand:
+            candidates = []
+            for top in to_expand:
+                for op in self.graph.get_operations():
+                    for tensor in op.inputs:
+                        if tensor in top.outputs:
+                            candidates.append(op)
+            to_expand = []
+            for op in candidates:
+                if op.type not in float_ups:
+                    if op.name not in in_str:
+                        in_str.append(op.name)
+                else:
+                    to_expand.append(op)
+        return str(in_str), str([op.name for op in self.output_nodes])
+
+    def _is_embedding_lookup(self, op):
+        if op.type in self.embedding_lookup_op_type:
+            return True
+
+        return False
+
+    def _check_op_status(self):
+        # report ops that were never visited during traversal
+        unseen_list = []
+        for name, op_node in self.op_node_lookup.items():
+            if not op_node.seen:
+                unseen_list.append(name)
+        return unseen_list
diff --git a/tools/graph_partition/template.cfg b/tools/graph_partition/template.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..fef30a9b670923e443484b8ea14b0a4b2ea3d4c5
--- /dev/null
+++ b/tools/graph_partition/template.cfg
@@ -0,0 +1,59 @@
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+platform_configs {
+  key: "tensorflow"
+  value {
+    source_adapter_config {
+      [type.googleapis.com/tensorflow.serving.SaveModelBundleSourceAdapterConfig] {
+        legacy_config {
+          session_config {
+            graph_options {
+              rewrite_options {
+                custom_optimizers {
+                  name: "NpuOptimizer"
+                  parameter_map: {
+                    key:"use_off_line"
+                    value:{
+                      b:true
+                    }
+                  }
+                  parameter_map: {
+                    key:"mix_compile_mode"
+                    value:{
+                      b:true
+                    }
+                  }
+                  parameter_map: {
+                    key:"variable_placement"
+                    value:{
+                      s:"Host"
+                    }
+                  }
+                  parameter_map: {
+                    key:"graph_run_mode"
+                    value:{
+                      i:0
+                    }
+                  }
+                  parameter_map: {
+                    key:"precision_mode"
+                    value:{
+                      s:"must_keep_origin_dtype"
+                    }
+                  }
+                  parameter_map: {
+                    key:"in_out_pair"
+                    value:{
+                      s:"#value@in_out_pair#"
+                    }
+                  }
+                }
+                remapping: OFF
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
\ No newline at end of file
diff --git a/tools/model_convert/model_convert.py b/tools/model_convert/model_convert.py
index 7608917a42b7a1428347f0502b6767c416c4c585..eb2432db587cadbfb0ae3ef41450803606de99ad 100644
--- a/tools/model_convert/model_convert.py
+++ b/tools/model_convert/model_convert.py
@@ -222,9 +222,9 @@ class ModelConverter:
         for _, dirs, _ in os.walk(check_dir):
             model_dirs.append(dirs)
         if not self._is_ddr and "DDR" in model_dirs[0]:
-            raise ValueError(f"wrong mode choose! you choose hbm mode, however ddr dir exists. ")
+            raise ValueError("wrong mode chosen: HBM mode was selected, but a DDR directory exists.")
         if self._is_ddr and "DDR" not in model_dirs[0]:
-            raise ValueError(f"wrong mode choose! you choose ddr mode, however ddr dir not exists. ")
+            raise ValueError("wrong mode chosen: DDR mode was selected, but no DDR directory exists.")
 
 
 def get_attribute_and_data_file(table_path):
diff --git a/tools/perfrec-python/README.md b/tools/perfrec-python/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ddc7e114db5abe818447c300c3190b136940848d
--- /dev/null
+++ b/tools/perfrec-python/README.md
@@ -0,0 +1,81 @@
+## perf.py
+```
+usage: perf.py [-h] --perf_data PERF_DATA --flamegraph_path FLAMEGRAPH_PATH
+               [--perf_bin PERF_BIN] [--output_svg OUTPUT_SVG]
+
+Generate a Flamegraph from perf.data.
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --perf_data PERF_DATA
+                        Path to the perf.data file.
+  --flamegraph_path FLAMEGRAPH_PATH
+                        Path to the Flamegraph Perl scripts directory.
+  --perf_bin PERF_BIN   Path to the perf executable binary. (default: perf)
+  --output_svg OUTPUT_SVG
+                        Path to the output SVG file. (default: flamegraph.svg)
+```
+#### Usage example
+
+Collect data with `perf`, for example with a script like the following.
+```bash
+pid=$(top -b -n 1 | head -n 8 | tail -n 1 | awk '{print $1}')
+if [ -z "$pid" ];then
+    echo "failed to get the process ID"
+    exit 1
+fi
+perf record -F 99 -p $pid -a -g -- sleep 60
+if [ $? -ne 0 ]; then
+    echo "perf record failed"
+    exit 1
+fi
+echo "perf.data collected"
+```
+
+Then use this tool to generate the flamegraph and the hot-function report.
+```bash
+python perf.py --perf_data perf.data --flamegraph_path /ws/FlameGraph
+```
+#### Optional configuration
+```toml
+# config.toml
+
+[perf]
+# Filter percentage of time cost
+threshold = 0.05
+# Ignore function list
+ignores = ["[libc.so.6]"]
+```
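+
+The `[perf]` options are read from the `config.toml` in the working directory.
+As a rough sketch of what the two options mean (illustrative; the exact
+filtering logic lives in perf.py): a function is kept in the hot-function
+report only if its share of the total time cost reaches `threshold` and its
+name is not listed in `ignores`.
+```python
+import toml
+
+perf_conf = toml.load("config.toml")["perf"]
+
+
+def keep(func_name: str, time_share: float) -> bool:
+    # drop functions below the time-cost threshold or on the ignore list
+    return (time_share >= perf_conf["threshold"]
+            and func_name not in perf_conf["ignores"])
+```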
+
+## fusion_tracing.py
+```
+usage: fusion_tracing.py [-h] --debug_log DEBUG_LOG
+                         [--msprof_output MSPROF_OUTPUT]
+
+Generate CPU/NPU fusion tracing json.
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --debug_log DEBUG_LOG
+                        MxRec DEBUG level log file path.
+  --msprof_output MSPROF_OUTPUT
+                        msprof output path.
+```
+#### Usage example
+```bash
+# CPU only
+python fusion_tracing.py --debug_log ../../example/demo/little_demo/temp.log
+# CPU + NPU
+python fusion_tracing.py --debug_log ../../example/demo/little_demo/temp.log --msprof_output ../../example/demo/little_demo/msprof
+```
+#### Optional configuration
+```toml
+# config.toml
+
+[mxrec]
+# Pipe name and time cost name
+key_process = ["getBatchData", "getAndProcess"]
+process_emb_info = ["getAndSendTensors"]
+lookup_swap_addr = ["lookupAddrs"]
+embedding_recv = ["EmbeddingRecv", "EmbeddingUpdate", "SendH2DEmb"]
+```
diff --git a/tools/perfrec-python/config.toml b/tools/perfrec-python/config.toml
new file mode 100644
index 0000000000000000000000000000000000000000..8e15fd1db647bc10a88efff75c231674642db810
--- /dev/null
+++ b/tools/perfrec-python/config.toml
@@ -0,0 +1,27 @@
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+[mxrec]
+# Pipe name and time cost name
+key_process = ["getBatchData", "getAndProcess"]
+process_emb_info = ["getAndSendTensors"]
+lookup_swap_addr = ["lookupAddrs"]
+embedding_recv = ["EmbeddingRecv", "EmbeddingUpdate", "SendH2DEmb"]
+
+[perf]
+# Filter percentage of time cost
+threshold = 0.05
+# Ignore function list
+ignores = ["[libc.so.6]"]
diff --git a/tools/perfrec-python/fusion_tracing.py b/tools/perfrec-python/fusion_tracing.py
new file mode 100644
index 0000000000000000000000000000000000000000..49900004f62789a83cc4a94bcefa5b6d2ca036e5
--- /dev/null
+++ b/tools/perfrec-python/fusion_tracing.py
@@ -0,0 +1,425 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import argparse
+import json
+import logging
+import os
+import re
+from collections import defaultdict
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Any, Dict, List, Tuple
+
+import pandas as pd
+import toml
+
+
+class MxRecConfig:
+    """
+    Configuration from `config.toml`.
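+
+    `pipes` maps a pipe (lane) name to the list of event names drawn in that
+    lane, mirroring the `[mxrec]` section of config.toml, e.g.
+    {"key_process": ["getBatchData", "getAndProcess"]}; `func_to_pipe` holds
+    the inverse mapping from a single event name to its pipe.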
+ """ + + def __init__(self, pipes: Dict[str, List[str]]): + self.pipes = pipes + self.func_to_pipe = defaultdict(str) + for pipe_name, event_list in self.pipes.items(): + for event in event_list: + self.func_to_pipe[event] = pipe_name + self.pipe_names = [name for name in pipes.keys()] + + +class MxRecEvent: + """ + Class to represent an MxRec event. + """ + + def __init__(self, log_line: str, event_name: str, pipe_id: int): + timestamp_s = get_timestamp(log_line) + duration_ms = get_duration(log_line, event_name) + process_id = get_process_id(log_line) + self.timestamp_start_us = timestamp_s * 1e6 - duration_ms * 1e3 + self.duration_us = duration_ms * 1e3 + self.timestamp_end_us = timestamp_s * 1e6 + self.process_id = process_id + self.name = event_name + self.pipe_id = pipe_id + + +@dataclass +class OpEvent: + """ + Class to represent an Op event. + """ + + device_id: int + op_name: str + op_type: str + task_type: str + start_timestamp: float + duration: float + + +def extract_mxrec_events( + log_path: str, config: MxRecConfig +) -> Dict[int, Dict[str, List[MxRecEvent]]]: + """ + Extracts MxRec events from the log file. + + Args: + log_path (str): Path to the log file. + config (MxRecConfig): Dictionary mapping event names to pipe names and other configs. + + Returns: + Dict[int, Dict[str, List[MxRecEvent]]]: Extracted MxRec events grouped by process ID and pipe. + """ + events: Dict[int, Dict[str, List[MxRecEvent]]] = defaultdict( + lambda: defaultdict(list) + ) + broken_lines = list() + event_names = config.func_to_pipe + pipe_names = config.pipe_names + pipe_ids = defaultdict(int) + for i, pipe in enumerate(pipe_names): + pipe_ids[pipe] = i + with open(log_path) as log: + for line in log: + for name, pipe in filter(lambda item: item[0] in line, event_names.items()): + try: + event = MxRecEvent(line, name, pipe_ids[pipe]) + events[event.process_id][pipe].append(event) + except RuntimeError: + broken_lines.append(line) + if broken_lines: + logging.warning("There are %d broken log lines", len(broken_lines)) + for line in broken_lines: + logging.warning(line) + return events + + +def extract_op_events(op_summary_path: str) -> List[OpEvent]: + """ + Extracts Op events from the CSV file. + + Args: + op_summary_path (str): Path to the op summary CSV file. + + Returns: + List[OpEvent]: List of extracted Op events. + """ + df = pd.read_csv(op_summary_path) + return [ + OpEvent( + row["Device_id"], + row["Op Name"], + row["OP Type"], + row["Task Type"], + row["Task Start Time(us)"], + row["Task Duration(us)"], + ) + for _, row in df.iterrows() + ] + + +def get_timestamp(log_line: str) -> float: + """ + Extracts the timestamp from a log line. + + Args: + log_line (str): A line from the log file. + + Returns: + float: The extracted timestamp as a float. + """ + pattern = r"\[(\d{4}/\d{1,2}/\d{1,2} \d{1,2}:\d{1,2}:\d{1,2}\.\d+)\]" + match = re.search(pattern, log_line) + if not match: + raise RuntimeError(f"there is no time in log: {log_line}") + date_time_str = match.group(1) + date_time_format = "%Y/%m/%d %H:%M:%S.%f" + # Parse the date-time string into a datetime object + date_time_obj = datetime.strptime(date_time_str, date_time_format) + # Convert the datetime object to a timestamp + return date_time_obj.timestamp() + + +def get_duration(log_line: str, event_name: str) -> float: + """ + Extracts the duration of an event from a log line. + + Args: + log_line (str): A line from the log file. + event_name (str): The name of the event. + + Returns: + int: The extracted duration in milliseconds. 
+ """ + pattern = event_name + r".*:\s*(\d+)" + match = re.search(pattern, log_line) + if not match: + raise RuntimeError(f"there is no event: {event_name}, log: {log_line}") + duration_ms = match.group(1) + return float(duration_ms) + + +def get_process_id(log_line: str) -> int: + """ + Extracts the process ID from a log line. + + Args: + log_line (str): A line from the log file. + + Returns: + int: The extracted process ID. + """ + pattern = r"\[(\d+)\]" + match = re.search(pattern, log_line) + if not match: + raise RuntimeError(f"there is no process_id in log: {log_line}") + process_id = match.group(1) + return int(process_id) + + +def read_mxrec_config() -> MxRecConfig: + """ + Reads the MxRec configuration from a TOML file. + + Returns: + MxRecCofig: Configuration class. + """ + try: + config = toml.load("config.toml") + return MxRecConfig(config["mxrec"]) + except toml.TomlDecodeError as e: + raise RuntimeError("can not load config.toml") from e + + +@dataclass +class TracingMetaData: + """ + Class to represent metadata for tracing. + """ + + name: str + pid: int + tid: int + ph: str + args: Dict[str, Any] + + +class TracingMxRecEvent: + """ + Class to represent a traced MxRec event. + """ + + def __init__(self, mxrec_event: MxRecEvent): + self.name = mxrec_event.name + self.pid = mxrec_event.process_id + self.tid = get_fake_tid(self.pid, mxrec_event.pipe_id) + self.ts = mxrec_event.timestamp_start_us + self.dur = mxrec_event.duration_us + self.ph = "X" + self.args = {} + + +class TracingOpEvent: + """ + Class to represent a traced Op event. + """ + + def __init__(self, op_event: OpEvent, tid: int): + self.name = op_event.op_type + self.pid = get_op_pid(op_event) + self.tid = tid + self.ts = op_event.start_timestamp + self.dur = op_event.duration + self.ph = "X" + self.args = {"Op Name": op_event.op_name} + + +def get_metadata(processes: List[int], config: MxRecConfig) -> List[TracingMetaData]: + """ + Generates metadata for tracing processes and threads. + + Args: + processes (List[int]): List of process IDs. + config (MxRecConfig): Configuration class. + + Returns: + List[TracingMetaData]: List of tracing metadata. + """ + metadata = list() + pipes = config.pipe_names + for i, pid in enumerate(processes): + metadata1 = TracingMetaData( + "process_name", pid, 0, "M", {"name": f"MxRec process {i}"} + ) + metadata2 = TracingMetaData( + "process_sort_index", pid, 0, "M", {"sort_index": i} + ) + metadata.append(metadata1) + metadata.append(metadata2) + for pipe_i, pipe in enumerate(pipes): + pipe_metadata1 = TracingMetaData( + "thread_name", + pid, + get_fake_tid(pid, pipe_i), + "M", + {"name": f"{pipe} {pid}"}, + ) + pipe_metadata2 = TracingMetaData( + "thread_sort_index", + pid, + get_fake_tid(pid, pipe_i), + "M", + {"sort_index": pipe_i}, + ) + metadata.append(pipe_metadata1) + metadata.append(pipe_metadata2) + return metadata + + +def get_fake_tid(pid: int, pipe_id: int) -> int: + """ + Generates a fake thread ID based on process ID and pipe ID. + + Args: + pid (int): Process ID. + pipe_id (int): Pipe ID. + + Returns: + int: Fake thread ID. + """ + return pid * 10 + pipe_id + + +def get_op_pid(op_event: OpEvent) -> int: + """ + Gets the process ID for an Op event. + + Args: + op_event (OpEvent): An Op event. + + Returns: + int: Process ID. + """ + # add 100 avoiding confict with cpu pid(rand_id) + return 100 + op_event.device_id + + +def get_op_tracing(path: str) -> Tuple[List[TracingMetaData], List[TracingOpEvent]]: + """ + Generates tracing data for Op events. 
+
+
+@dataclass
+class TracingMetaData:
+    """
+    Class to represent metadata for tracing.
+
+    The fields mirror the Chrome trace event format; "M" in `ph` marks a
+    metadata event.
+    """
+
+    name: str
+    pid: int
+    tid: int
+    ph: str
+    args: Dict[str, Any]
+
+
+class TracingMxRecEvent:
+    """
+    Class to represent a traced MxRec event.
+    """
+
+    def __init__(self, mxrec_event: MxRecEvent):
+        self.name = mxrec_event.name
+        self.pid = mxrec_event.process_id
+        self.tid = get_fake_tid(self.pid, mxrec_event.pipe_id)
+        self.ts = mxrec_event.timestamp_start_us
+        self.dur = mxrec_event.duration_us
+        self.ph = "X"  # "X" is a complete event in the Chrome trace format
+        self.args = {}
+
+
+class TracingOpEvent:
+    """
+    Class to represent a traced Op event.
+    """
+
+    def __init__(self, op_event: OpEvent, tid: int):
+        self.name = op_event.op_type
+        self.pid = get_op_pid(op_event)
+        self.tid = tid
+        self.ts = op_event.start_timestamp
+        self.dur = op_event.duration
+        self.ph = "X"
+        self.args = {"Op Name": op_event.op_name}
+
+
+def get_metadata(processes: List[int], config: MxRecConfig) -> List[TracingMetaData]:
+    """
+    Generates metadata for tracing processes and threads.
+
+    Args:
+        processes (List[int]): List of process IDs.
+        config (MxRecConfig): Configuration class.
+
+    Returns:
+        List[TracingMetaData]: List of tracing metadata.
+    """
+    metadata = list()
+    pipes = config.pipe_names
+    for i, pid in enumerate(processes):
+        metadata1 = TracingMetaData(
+            "process_name", pid, 0, "M", {"name": f"MxRec process {i}"}
+        )
+        metadata2 = TracingMetaData(
+            "process_sort_index", pid, 0, "M", {"sort_index": i}
+        )
+        metadata.append(metadata1)
+        metadata.append(metadata2)
+        for pipe_i, pipe in enumerate(pipes):
+            pipe_metadata1 = TracingMetaData(
+                "thread_name",
+                pid,
+                get_fake_tid(pid, pipe_i),
+                "M",
+                {"name": f"{pipe} {pid}"},
+            )
+            pipe_metadata2 = TracingMetaData(
+                "thread_sort_index",
+                pid,
+                get_fake_tid(pid, pipe_i),
+                "M",
+                {"sort_index": pipe_i},
+            )
+            metadata.append(pipe_metadata1)
+            metadata.append(pipe_metadata2)
+    return metadata
+
+
+def get_fake_tid(pid: int, pipe_id: int) -> int:
+    """
+    Generates a fake thread ID based on process ID and pipe ID.
+
+    Args:
+        pid (int): Process ID.
+        pipe_id (int): Pipe ID.
+
+    Returns:
+        int: Fake thread ID.
+    """
+    # Assumes fewer than 10 pipes per process, so every (pid, pipe_id)
+    # pair maps to a unique tid.
+    return pid * 10 + pipe_id
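+
+
+# For example, pid 4242 with pipe_id 3 yields the fake tid 42423, and the
+# original pair can be recovered with divmod(tid, 10) if ever needed.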
+
+
+def get_op_pid(op_event: OpEvent) -> int:
+    """
+    Gets the process ID for an Op event.
+
+    Args:
+        op_event (OpEvent): An Op event.
+
+    Returns:
+        int: Process ID.
+    """
+    # Offset by 100 to avoid colliding with the CPU process ids.
+    return 100 + op_event.device_id
+
+
+def get_op_tracing(path: str) -> Tuple[List[TracingMetaData], List[TracingOpEvent]]:
+    """
+    Generates tracing data for Op events.
+
+    Args:
+        path (str): Path to the directory containing Op event summaries.
+
+    Returns:
+        Tuple[List[TracingMetaData], List[TracingOpEvent]]: Metadata and tracing events.
+    """
+    task_types = dict()
+    pids = set()
+    tids = set()
+    metadata = list()
+    op_tracing = list()
+
+    def new_process_metadata(pid, device_id):
+        metadata1 = TracingMetaData(
+            "process_name", pid, 0, "M", {"name": f"NPU {device_id}"}
+        )
+        metadata2 = TracingMetaData(
+            "process_sort_index", pid, 0, "M", {"sort_index": pid}
+        )
+        return [metadata1, metadata2]
+
+    def new_thread_metadata(pid, tid, name):
+        metadata1 = TracingMetaData("thread_name", pid, tid, "M", {"name": f"{name}"})
+        metadata2 = TracingMetaData(
+            "thread_sort_index", pid, tid, "M", {"sort_index": tid}
+        )
+        return [metadata1, metadata2]
+
+    for root, _, files in os.walk(path):
+        for file in files:
+            if (
+                root.endswith("mindstudio_profiler_output")
+                and file.startswith("op_summary")
+                and file.endswith(".csv")
+            ):
+                file_path = os.path.join(root, file)
+                op_events = extract_op_events(file_path)
+                for event in op_events:
+                    process_id = get_op_pid(event)
+                    if process_id not in pids:
+                        pids.add(process_id)
+                        metadata.extend(
+                            new_process_metadata(process_id, event.device_id)
+                        )
+                    # Assign each task type a stable small id, reused as the
+                    # pipe slot of the fake tid.
+                    if event.task_type not in task_types:
+                        task_types[event.task_type] = len(task_types)
+                    tid = get_fake_tid(process_id, task_types[event.task_type])
+                    if tid not in tids:
+                        tids.add(tid)
+                        metadata.extend(
+                            new_thread_metadata(process_id, tid, event.task_type)
+                        )
+                    op_tracing.append(TracingOpEvent(event, tid))
+    return metadata, op_tracing
+
+
+def main():
+    """
+    Main function to parse arguments and generate tracing JSON.
+    """
+    logging.basicConfig(level=logging.INFO)
+    parser = argparse.ArgumentParser(
+        description="Generate CPU/NPU fusion tracing json."
+    )
+    parser.add_argument(
+        "--debug_log", help="MxRec DEBUG level log file path.", required=True
+    )
+    parser.add_argument("--msprof_output", help="msprof output path.", required=False)
+    args = parser.parse_args()
+
+    log_path = args.debug_log
+    tracing = list()
+    try:
+        config = read_mxrec_config()
+        mxrec_events = extract_mxrec_events(log_path, config)
+        tracing.extend(get_metadata(list(mxrec_events.keys()), config))
+    except RuntimeError:
+        logging.error("Cannot read config.toml; exiting.")
+        exit(1)
+
+    for process in mxrec_events.values():
+        for events in process.values():
+            tracing.extend([TracingMxRecEvent(event) for event in events])
+
+    msprof_output_path = args.msprof_output
+    if msprof_output_path:
+        op_metadata, op_tracing = get_op_tracing(msprof_output_path)
+        tracing.extend(op_metadata)
+        tracing.extend(op_tracing)
+
+    fd = os.open("mxrec_tracing.json", os.O_WRONLY | os.O_CREAT, 0o640)
+    with os.fdopen(fd, "w") as file:
+        # Serialize event and metadata objects through their __dict__ so they
+        # dump as plain JSON objects.
+        json.dump(tracing, file, indent=4, default=lambda obj: obj.__dict__)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/perfrec-python/perf.py b/tools/perfrec-python/perf.py
new file mode 100644
index 0000000000000000000000000000000000000000..34f688e90373055fa3bdece7b1aba5eb9338d324
--- /dev/null
+++ b/tools/perfrec-python/perf.py
@@ -0,0 +1,251 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import argparse
+import logging
+import os
+import subprocess
+from collections import defaultdict
+from typing import List
+
+import toml
+from tabulate import tabulate
+
+
+def generate_flamegraph(
+    perf_bin: str, perf_data: str, output_svg: str, flamegraph_path: str
+) -> None:
+    """
+    Generate a flamegraph from perf data.
+
+    Args:
+        perf_bin (str): Path to the perf executable.
+        perf_data (str): Path to the perf.data file.
+        output_svg (str): Path to the output SVG file.
+        flamegraph_path (str): Path to the Flamegraph scripts directory.
+    """
+    # Ensure perf is available; FileNotFoundError covers a missing binary.
+    try:
+        subprocess.run([perf_bin, "--version"], shell=False, check=True)
+    except (subprocess.CalledProcessError, FileNotFoundError):
+        logging.error("perf is not installed or not in PATH.")
+        return
+
+    # Ensure Flamegraph scripts are available
+    stackcollapse_path = os.path.join(flamegraph_path, "stackcollapse-perf.pl")
+    flamegraph_script_path = os.path.join(flamegraph_path, "flamegraph.pl")
+
+    if not os.path.isfile(stackcollapse_path) or not os.path.isfile(
+        flamegraph_script_path
+    ):
+        logging.error(
+            "Flamegraph scripts not found in the provided directory %s.",
+            flamegraph_path,
+        )
+        return
+
+    # Generate the folded stack output
+    folded_output = perf_data + ".folded"
+    fd = os.open(folded_output, os.O_WRONLY | os.O_CREAT, 0o640)
+    with os.fdopen(fd, "w") as f:
+        script_output = subprocess.run(
+            [perf_bin, "script", "-i", perf_data],
+            shell=False,
+            check=True,
+            stdout=subprocess.PIPE,
+        )
+        subprocess.run(
+            [stackcollapse_path],
+            shell=False,
+            check=True,
+            input=script_output.stdout,
+            stdout=f,
+        )
+
+    # Generate the flamegraph
+    fd_svg = os.open(output_svg, os.O_WRONLY | os.O_CREAT, 0o640)
+    with os.fdopen(fd_svg, "w") as f:
+        subprocess.run(
+            [flamegraph_script_path, folded_output], shell=False, check=True, stdout=f
+        )
+
+    logging.info("Flamegraph generated at %s", output_svg)
+
+    # Analyze the folded stack output
+    analyze_folded_stack(folded_output)
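+
+
+# generate_flamegraph() is the scripted equivalent of the usual manual
+# FlameGraph pipeline (file names are illustrative):
+#
+#   perf script -i perf.data | ./stackcollapse-perf.pl > perf.data.folded
+#   ./flamegraph.pl perf.data.folded > flamegraph.svg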
+ """ + + function_counts = defaultdict(CallStack) + total_count = 0 + + # Read the folded stack output + # Line of folded stack example: + # python3.7;[libascendalog.so];access;__sys_trace_return;prepare_creds 10101010 + with open(folded_output, "r") as f: + for line in f: + parts = line.strip().rsplit( + " ", 1 + ) # Use rsplit to handle function names with spaces + count = int(parts[-1]) + call_stack_str = parts[0] + stack = parts[0].split(";") + function_counts[stack[-1]].add_call_stacks(count, call_stack_str) + total_count += count + + config = read_config() + + # Filter and display functions with more than 5% total count + threshold = total_count * config.threshold + results = [ + (func, call_stack) + for func, call_stack in function_counts.items() + if call_stack.count >= threshold and func not in config.ignores + ] + + # Sort results by count in descending order + results.sort(key=lambda x: x[1].count, reverse=True) + + # Prepare data for tabulate + # Write call stacks to file + table_data = [] + fd_call_stacks = os.open("call_stacks.txt", os.O_WRONLY | os.O_CREAT, 0o640) + with os.fdopen(fd_call_stacks, "w") as f: + for func, call_stack in results: + percentage = ( + (call_stack.count / total_count) * 100 if total_count != 0 else 0 + ) + table_data.append( + [limit_line(func, 50), call_stack.count, f"{percentage:.2f}%"] + ) + stacks = [stk + "\n" for stk in call_stack.call_stacks] + f.writelines( + [ + f"func_name: {func}\n", + f"percentage: {percentage:.2f}%\n", + "call_stacks:\n", + ] + + stacks + + ["\n\n"] + ) + + # Print the results using tabulate + logging.info("\nFunctions with more than 5% of total samples:") + headers = ["Function", "Count", "Percentage"] + logging.info("\n%s", tabulate(table_data, headers=headers, tablefmt="grid")) + + +def limit_line(input_content: str, line_length: int) -> str: + """ + Limits the length of a line to a specified number of characters, adding line breaks if necessary. + + Args: + input_content (str): The input string. + line_length (int): The maximum line length. + + Returns: + str: The formatted string with line breaks. + """ + if line_length >= len(input_content): + return input_content + limited_str = "" + if line_length > 0: + limited_str = "\n".join( + input_content[i : i + line_length] + for i in range(len(input_content), line_length) + ) + return limited_str + + +class PerfConfig: + """ + Configuration from `config.toml`. + """ + + def __init__(self, ignores: List[str], threshold: float = 0.05): + self.ignores = set(ignores) + self.threshold = threshold + + +def read_config() -> PerfConfig: + """ + Reads configs related to `perf` from the configuration file. + + Returns: + PerfConfig: Configuration class. + """ + try: + config = toml.load("config.toml") + perf_config = config["perf"] + return PerfConfig(perf_config["ignores"], perf_config["threshold"]) + except toml.TomlDecodeError: + return PerfConfig(ignores=[]) + + +def main(): + """ + Main function to parse arguments and generate a flamegraph. + """ + logging.basicConfig(level=logging.INFO) + parser = argparse.ArgumentParser( + description="Generate a Flamegraph from perf.data." + ) + parser.add_argument( + "--perf_data", help="Path to the perf.data file.", required=True + ) + parser.add_argument( + "--flamegraph_path", + help="Path to the Flamegraph Perl scripts directory.", + required=True, + ) + parser.add_argument( + "--perf_bin", + help="Path to perf exacutable binary file. 
(default: perf)", + required=False, + default="perf", + ) + parser.add_argument( + "--output_svg", + help="Path to the output SVG file. (default: flamegraph.svg)", + required=False, + default="flamegraph.svg", + ) + args = parser.parse_args() + + generate_flamegraph( + args.perf_bin, args.perf_data, args.output_svg, args.flamegraph_path + ) + + +if __name__ == "__main__": + main()