From 5534b7b35e2d291a976ab4191b94167186da76ca Mon Sep 17 00:00:00 2001 From: yangzhen Date: Sun, 7 Apr 2024 14:55:40 +0800 Subject: [PATCH 1/7] =?UTF-8?q?=E5=90=8C=E6=AD=A5AccCTR=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/AccCTR/3rdparty/CMakeLists.txt | 14 + src/AccCTR/CMakeLists.txt | 16 +- src/AccCTR/README.md | 4 +- src/AccCTR/src/CMakeLists.txt | 8 +- src/AccCTR/src/common/util/error_code.h | 15 +- .../src/common/util/external_threader.h | 70 + src/AccCTR/src/embedding_cache/CMakeLists.txt | 27 + .../cache_manager/cache_manager.cpp | 379 ++++ .../cache_manager/cache_manager.h | 86 + src/AccCTR/src/embedding_cache/common.h | 159 ++ .../embedding_local_table/emb_local_table.cpp | 385 ++++ .../embedding_local_table/emb_local_table.h | 69 + .../constant_initializer.cpp | 56 + .../initializer/initializer.cpp | 56 + .../random_normal_initializer.cpp | 72 + .../truncated_normal_initializer.cpp | 80 + .../offset_mapper/address_mapper.h | 311 ++++ .../offset_mapper/mapper_base.h | 785 ++++++++ .../offset_mapper/offset_mapper.h | 209 +++ src/AccCTR/src/factory_impl.cpp | 11 + src/AccCTR/src/factory_impl.h | 2 + src/AccCTR/src/include/CMakeLists.txt | 2 +- src/AccCTR/src/include/embedding_cache.h | 295 +++ src/AccCTR/src/include/factory.h | 5 +- src/AccCTR/src/include/ock_ctr_common_def.h | 2 +- src/AccCTR/src/include/unique.h | 1 + src/AccCTR/src/unique/unique_func.cpp | 56 +- src/AccCTR/src/unique/unique_func.h | 237 +-- src/AccCTR/src/unique/unique_impl.cpp | 8 + src/AccCTR/src/unique/unique_impl.h | 2 +- src/AccCTR/tests/tools/create_fake_id.py | 6 - src/AccCTR/tests/ut/conf/toolchain.cmake | 24 + src/AccCTR/tests/ut/src/CMakeLists.txt | 26 +- src/AccCTR/tests/ut/src/common.h | 64 + src/AccCTR/tests/ut/src/emb_cache_test.cpp | 1653 +++++++++++++++++ src/AccCTR/tests/ut/src/emb_cache_test.h | 62 + src/AccCTR/tests/ut/src/unique_test.cpp | 53 +- src/AccCTR/tests/ut/src/unique_test.h | 16 - 38 files changed, 5154 insertions(+), 172 deletions(-) create mode 100644 src/AccCTR/src/embedding_cache/CMakeLists.txt create mode 100644 src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp create mode 100644 src/AccCTR/src/embedding_cache/cache_manager/cache_manager.h create mode 100644 src/AccCTR/src/embedding_cache/common.h create mode 100644 src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.cpp create mode 100644 src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.h create mode 100644 src/AccCTR/src/embedding_cache/initializer/constant_initializer/constant_initializer.cpp create mode 100644 src/AccCTR/src/embedding_cache/initializer/initializer.cpp create mode 100644 src/AccCTR/src/embedding_cache/initializer/random_normal_initializer/random_normal_initializer.cpp create mode 100644 src/AccCTR/src/embedding_cache/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp create mode 100644 src/AccCTR/src/embedding_cache/offset_mapper/address_mapper.h create mode 100644 src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h create mode 100644 src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h create mode 100644 src/AccCTR/src/include/embedding_cache.h create mode 100644 src/AccCTR/tests/ut/conf/toolchain.cmake create mode 100644 src/AccCTR/tests/ut/src/common.h create mode 100644 src/AccCTR/tests/ut/src/emb_cache_test.cpp create mode 100644 src/AccCTR/tests/ut/src/emb_cache_test.h diff --git a/src/AccCTR/3rdparty/CMakeLists.txt b/src/AccCTR/3rdparty/CMakeLists.txt index a17e472c..3a05f585 100644 --- a/src/AccCTR/3rdparty/CMakeLists.txt +++ b/src/AccCTR/3rdparty/CMakeLists.txt @@ -1,3 +1,17 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + message("build mode " ${BUILD_MODE}) set(PLATFORM_UTILITIES_3RDPARTY_SOURCE_DIR ${PROJECT_SOURCE_DIR}/../../../opensource) diff --git a/src/AccCTR/CMakeLists.txt b/src/AccCTR/CMakeLists.txt index 0cb63176..60e2d638 100644 --- a/src/AccCTR/CMakeLists.txt +++ b/src/AccCTR/CMakeLists.txt @@ -23,8 +23,6 @@ if (${BUILD_MODE} MATCHES "release") -Wall -fPIC -fms-extensions - -Wno-unused-parameter - -Wno-unused-function -Wunused-variable -Wunused-value -Wcast-align @@ -47,8 +45,6 @@ elseif (${BUILD_MODE} MATCHES "debug") -Wall -fPIC -fms-extensions - -Wno-unused-parameter - -Wno-unused-function -Wunused-variable -Wunused-value -Winvalid-pch @@ -67,8 +63,6 @@ elseif (${BUILD_MODE} MATCHES "ut") -Wall -fPIC -fms-extensions - -Wno-unused-parameter - -Wno-unused-function -Wunused-variable -Wunused-value -Winvalid-pch @@ -79,10 +73,6 @@ elseif (${BUILD_MODE} MATCHES "ut") -Wfloat-equal -Wextra -std=c++17 - #-fsanitize=address - #-fno-omit-frame-pointer - #-fstack-protector-all - #-fstack-protector-strong ) else () message(FATAL_ERROR "======BUILD_MODE not found") @@ -100,7 +90,6 @@ elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64") ${CXX_FLAGS} -msse2 -mavx - #-w ) else () message(FATAL_ERROR "don't support ${CMAKE_HOST_SYSTEM_PROCESSOR}") @@ -110,6 +99,11 @@ set(OCK_CTR_PLATFORM_UTIL_DIR ${PROJECT_SOURCE_DIR}/../../../opensource) message(===============${OCK_CTR_PLATFORM_UTIL_DIR}) include_directories(${OCK_CTR_PLATFORM_UTIL_DIR}/securec/include) +include_directories( + ${PROJECT_SOURCE_DIR}/src + ${PROJECT_SOURCE_DIR}/src/embedding_cache +) + add_subdirectory(3rdparty) add_subdirectory(src) diff --git a/src/AccCTR/README.md b/src/AccCTR/README.md index 1a394699..1b25534d 100644 --- a/src/AccCTR/README.md +++ b/src/AccCTR/README.md @@ -6,4 +6,6 @@ 2、bash build.sh debug //编译debug -3、bash build.sh ut //编译并运行ut,覆盖率在tests/build/cov/gen目录下 +3、编译和运行UT: + (1)bash build.sh ut //编译ut,覆盖率在tests/build/cov/gen目录下 + (2)cd build && bash build_test.sh ut //进入到build目录下并运行ut \ No newline at end of file diff --git a/src/AccCTR/src/CMakeLists.txt b/src/AccCTR/src/CMakeLists.txt index 09da4670..5aaa168d 100644 --- a/src/AccCTR/src/CMakeLists.txt +++ b/src/AccCTR/src/CMakeLists.txt @@ -23,12 +23,16 @@ set(OUTPUT ${PROJECT_SOURCE_DIR}/output) set(OCK_CTR_PLATFORM_UTIL_DIR ${PROJECT_SOURCE_DIR}/../../../opensource) set(OCK_CTR_UTIL_INSTALL_DIR ${PROJECT_SOURCE_DIR}/install) - +add_compile_options(-D_GLIBCXX_USE_CXX11_ABI=0) # must set this option otherwise pybind will not find embCache symbol if (${BUILD_MODE} MATCHES "ut") add_compile_options(-ftest-coverage -fprofile-arcs) link_libraries(gcov) endif (${BUILD_MODE} MATCHES "ut") +if (${BUILD_MODE} MATCHES "fuzz") + add_compile_options(-ftest-coverage -fprofile-arcs -fdump-rtl-expand) + link_libraries(gcov asan) +endif (${BUILD_MODE} MATCHES "fuzz") message("include : " ${OCK_CTR_SRC_INCLUDE_DIR}) @@ -37,6 +41,7 @@ set(LIB_HW_SECURE ${OCK_CTR_PLATFORM_UTIL_DIR}/securec/lib/libsecurec.so) add_subdirectory(include) add_subdirectory(common) add_subdirectory(unique) +add_subdirectory(embedding_cache) file(GLOB_RECURSE CTR_SRC factory_impl.cpp) @@ -52,6 +57,7 @@ target_include_directories(_ock_ctr_common target_link_libraries(_ock_ctr_common PUBLIC -Wl,--start-group unique + embedding_cache dl utils ${LIB_HW_SECURE} diff --git a/src/AccCTR/src/common/util/error_code.h b/src/AccCTR/src/common/util/error_code.h index 04d26a57..e808779b 100644 --- a/src/AccCTR/src/common/util/error_code.h +++ b/src/AccCTR/src/common/util/error_code.h @@ -29,7 +29,20 @@ using CTRCode = enum : int { H_OUTPUT_TYPE_ERROR = 8, H_SCENE_ERROR = 9, H_MEMORY_ALLOC_ERROR = 10, - H_UNIQUE_UNINITIALIZED_ERROR = 11 + H_UNIQUE_UNINITIALIZED_ERROR = 11, + H_TABLE_NOT_EXIST = 12, + H_BUFFER_INVALID = 13, + H_INITIALIZER_INVALID = 14, + H_EXT_EMBEDDING_SIZE_INVALID = 15, + H_MAX_CACHESIZE_TOO_SMALL = 16, + H_HOST_VOCAB_SIZE_TOO_SMALL = 17, + H_THREAD_NUM_ERROR = 18, + H_TABLE_CREATE_DUPLICATE = 19, + H_ARG_NOT_EMPTY = 20, + H_SIZE_ZERO = 21, + H_TABLE_NAME_EMPTY = 22, + H_PREFILL_BUFFER_SIZE_INVALID = 23, + H_TABLE_NAME_TOO_LONG = 24, }; } } diff --git a/src/AccCTR/src/common/util/external_threader.h b/src/AccCTR/src/common/util/external_threader.h index 5a1132af..e6b723d7 100644 --- a/src/AccCTR/src/common/util/external_threader.h +++ b/src/AccCTR/src/common/util/external_threader.h @@ -20,11 +20,81 @@ limitations under the License. #include #include #include +#include +#include +#include +#include #include "singleton.h" using ExternalThread = void (*)(const std::vector> &tasks); namespace ock { +class ThreadPoolAsync { +public: + ThreadPoolAsync() : stop(false) {} + + ~ThreadPoolAsync() + { + { + std::lock_guard lock(taskMutex); + stop = true; + } + taskCv.notify_all(); + for (auto &t : workerThreads) { + t.join(); + } + } + + void SetNumThreads(int n) + { + if (n < 1) { + return; + } + + for (int i = 0; i < n; ++i) { + workerThreads.emplace_back(std::bind(&ThreadPoolAsync::WorkerThread, this)); + } + } + + template std::future AddTask(F &&f) + { + std::lock_guard lock(taskMutex); + + auto pt = std::make_unique>(f); + auto fut = pt->get_future(); + tasks.emplace(std::move(pt)); + taskCv.notify_one(); + return fut; + } + +private: + std::vector workerThreads; + std::queue>> tasks; + std::mutex taskMutex; + std::condition_variable taskCv; + volatile bool stop = false; + + void WorkerThread() + { + while (true) { + std::unique_ptr> task; + { + std::unique_lock lock(taskMutex); + while (tasks.empty() && !stop) { + taskCv.wait(lock); + } + if (stop) { + break; + } + task = std::move(tasks.front()); + tasks.pop(); + } + (*task)(); + } + } +}; + + class SimpleThreadPool { public: static void SyncRun(const std::vector> &tasks) diff --git a/src/AccCTR/src/embedding_cache/CMakeLists.txt b/src/AccCTR/src/embedding_cache/CMakeLists.txt new file mode 100644 index 00000000..e0278a6e --- /dev/null +++ b/src/AccCTR/src/embedding_cache/CMakeLists.txt @@ -0,0 +1,27 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +file(GLOB_RECURSE SRCS *.cpp *.h) + +add_library(embedding_cache OBJECT ${SRCS}) + +target_link_libraries(embedding_cache + -Wl,--start-group + -Wl,--end-group + ) + +target_include_directories(embedding_cache + PUBLIC + ${PROJECT_SOURCE_DIR}/src/common/util + ${PROJECT_SOURCE_DIR}/src/include) \ No newline at end of file diff --git a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp new file mode 100644 index 00000000..991307fd --- /dev/null +++ b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp @@ -0,0 +1,379 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#include + +#include "external_logger.h" +#include "cache_manager.h" + +using namespace EmbCache; +using namespace ock; +using namespace ock::ctr; + +int64_t EmbCache::INVALID_KEY = -1; + +int EmbCacheManagerImpl::CreateCacheForTable(const EmbCacheInfo &embCacheInfo, + const std::vector &initializerInfos, int64_t invalidKey, uint64_t prefillBufferSize, + uint32_t refillThreadNum) +{ + int checkTableNameRet = CheckCreateTableName(embCacheInfo.tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + + if (embCacheInfo.extEmbeddingSize == 0 || embCacheInfo.embeddingSize == 0 || embCacheInfo.vocabSize == 0 || + embCacheInfo.maxCacheSize == 0) { + ExternalLogger::PrintLog(LogLevel::ERROR, "size must be positive"); + return H_SIZE_ZERO; + } + + if (embCacheInfo.vocabSize < embCacheInfo.maxCacheSize) { + ExternalLogger::PrintLog(LogLevel::ERROR, "vocabSize must be greater than or equal to maxCacheSize"); + return H_HOST_VOCAB_SIZE_TOO_SMALL; + } + + auto om = offsetMappers.find(embCacheInfo.tableName); + auto embTable = embTables.find(embCacheInfo.tableName); + if (om != offsetMappers.end() || embTable != embTables.end()) { + ExternalLogger::PrintLog(LogLevel::ERROR, "This table has already been created"); + return H_TABLE_CREATE_DUPLICATE; + } + + if (embCacheInfo.extEmbeddingSize % embCacheInfo.embeddingSize != 0) { + ExternalLogger::PrintLog(LogLevel::ERROR, "extEmbeddingSize = embeddingSize + optimizerSize, " + "which is divisible by embeddingSize"); + return H_EXT_EMBEDDING_SIZE_INVALID; + } + + if (!CheckInitializer(embCacheInfo.extEmbeddingSize, initializerInfos)) { + return H_INITIALIZER_INVALID; + } + + if ((prefillBufferSize < 1) || (prefillBufferSize > embCacheInfo.vocabSize)) { + ExternalLogger::PrintLog(LogLevel::ERROR, "prefillBufferSize has to be between [1, hostVocabSize]"); + return H_PREFILL_BUFFER_SIZE_INVALID; + } + + if (!CheckValidThreadNum(refillThreadNum)) { + return H_THREAD_NUM_ERROR; + } + + uint32_t reserve = embCacheInfo.vocabSize / VOCAB_CACHE_RATIO; + if (!offsetMappers[embCacheInfo.tableName].Initialize(reserve, embCacheInfo.maxCacheSize)) { + offsetMappers.erase(embCacheInfo.tableName); + return H_MEMORY_ALLOC_ERROR; + } + + EmbPoolParam embPoolParam{ prefillBufferSize, refillThreadNum }; + + if (!embTables[embCacheInfo.tableName].Initialize(embCacheInfo.extEmbeddingSize, embCacheInfo.vocabSize, reserve, + initializerInfos, embPoolParam)) { + offsetMappers.erase(embCacheInfo.tableName); + embTables.erase(embCacheInfo.tableName); + return H_MEMORY_ALLOC_ERROR; + } + + embCacheInfos.insert({ embCacheInfo.tableName, embCacheInfo }); + INVALID_KEY = invalidKey; + return H_OK; +} + +int EmbCacheManagerImpl::GetSwapPairsAndKey2Offset(std::string tableName, std::vector &keys, + KeyOffsetPair &swapInKoPair, KeyOffsetPair &swapOutKoPair) +{ + int checkRet = CheckGetSwapPairsAndKey2Offset(tableName, swapInKoPair, swapOutKoPair); + if (checkRet != H_OK) { + return checkRet; + } + return offsetMappers[tableName].GetSwapPairsAndKey2Offset(keys, swapInKoPair.first, swapInKoPair.second, + swapOutKoPair.first, swapOutKoPair.second); +} + +int EmbCacheManagerImpl::EmbeddingLookup(std::string tableName, const std::vector &keys, float *embAddr, + uint32_t threadNum) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + + if (!CheckValidThreadNum(threadNum)) { + return H_THREAD_NUM_ERROR; + } + + if (keys.empty()) { + return H_OK; + } + + if (embAddr == nullptr) { + ExternalLogger::PrintLog(LogLevel::ERROR, "embAddr is nullptr"); + return H_ADDRESS_NULL; + } + + return embTables[tableName].Gather(reinterpret_cast(embAddr), keys, threadNum); +} + +int EmbCacheManagerImpl::EmbeddingLookupAddrs(std::string tableName, const std::vector &keys, + std::vector &addrs, uint32_t threadNum) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + + if (!CheckValidThreadNum(threadNum)) { + return H_THREAD_NUM_ERROR; + } + + if (keys.empty()) { + return H_OK; + } + + return embTables[tableName].GatherAddrs(keys, addrs, threadNum); +} + + +// 如果多线程使用,严格保证传入的key线程间不会重复(unique key),否则可能出现未定义结果 +int EmbCacheManagerImpl::EmbeddingLookupAndRemove(std::string tableName, const std::vector &keys, + float *embAddr, uint32_t threadNum) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + + if (!CheckValidThreadNum(threadNum)) { + return H_THREAD_NUM_ERROR; + } + + if (keys.empty()) { + return H_OK; + } + + if (embAddr == nullptr) { + ExternalLogger::PrintLog(LogLevel::ERROR, "embAddr is nullptr"); + return H_ADDRESS_NULL; + } + + return embTables[tableName].GatherAndRemove(reinterpret_cast(embAddr), keys, threadNum); +} + +int EmbCacheManagerImpl::EmbeddingUpdate(std::string tableName, const std::vector &keys, float *embAddr, + uint32_t threadNum) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + + if (!CheckValidThreadNum(threadNum)) { // 检查thread是否小于核数 + return H_THREAD_NUM_ERROR; + } + + if (keys.empty()) { + return H_OK; + } + + if (embAddr == nullptr) { // 检查embAddr是不是空指针 + ExternalLogger::PrintLog(LogLevel::ERROR, "embAddr is nullptr"); + return H_ADDRESS_NULL; + } + + return embTables[tableName].Scatter(reinterpret_cast(embAddr), keys, threadNum); +} + +int EmbCacheManagerImpl::EmbeddingRemove(std::string tableName, const std::vector &keys, uint32_t threadNum) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + + if (keys.empty()) { + return H_OK; + } + + return embTables[tableName].RemoveByKeys(keys, threadNum); +} + +int EmbCacheManagerImpl::RemoveEmbsByKeys(std::string tableName, const std::vector &keys) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + auto om = offsetMappers.find(tableName); + auto embTable = embTables.find(tableName); + for (auto key : keys) { + if (key == static_cast(INVALID_KEY)) { + ExternalLogger::PrintLog(LogLevel::WARN, "Try to evict invalid key"); + continue; + } + om->second.Remove(key); + embTable->second.Remove(key); + } + return H_OK; +} + +int EmbCacheManagerImpl::GetEmbTableNames(std::vector &allTableNames) +{ + if (!allTableNames.empty()) { + ExternalLogger::PrintLog(LogLevel::ERROR, "allTableNames should be empty"); + return H_ARG_NOT_EMPTY; + } + allTableNames.reserve(embTables.size()); + for (auto &embTable : embTables) { + allTableNames.emplace_back(embTable.first); + } + return H_OK; +} + +int EmbCacheManagerImpl::ExportDeviceKeyOffsetPairs(std::string tableName, + std::vector> &koVec) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + OffsetMapper &om = offsetMappers[tableName]; + koVec = om.ExportSortedKVPairs(); + return H_OK; +} + +int EmbCacheManagerImpl::Serialize(std::string tableName, std::vector &buffer) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + buffer = embTables[tableName].Serialize(); + return H_OK; +} + +int EmbCacheManagerImpl::Deserialize(std::string tableName, const std::vector &buffer) +{ + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + if (!embTables[tableName].Deserialize(buffer)) { + return H_BUFFER_INVALID; + } + return H_OK; +} + +void EmbCacheManagerImpl::Destroy() +{ + for (auto it = offsetMappers.begin(); it != offsetMappers.end(); it++) { + it->second.UnInitialize(); + } + for (auto it = embTables.begin(); it != embTables.end(); it++) { + it->second.UnInitialize(); + } + embCacheInfos.clear(); + offsetMappers.clear(); + embTables.clear(); +} + +int EmbCacheManagerImpl::CheckValidTableName(std::string tableName) +{ + if (tableName.size() > TABLE_NAME_MAX_SIZE) { + ExternalLogger::PrintLog(LogLevel::ERROR, + "tableName size can not larger than " + std::to_string(TABLE_NAME_MAX_SIZE)); + return H_TABLE_NAME_TOO_LONG; + } + auto om = offsetMappers.find(tableName); + auto embTable = embTables.find(tableName); + if (om == offsetMappers.end() || embTable == embTables.end()) { + ExternalLogger::PrintLog(LogLevel::ERROR, "can not find table"); + return H_TABLE_NOT_EXIST; + } + return H_OK; +} + +bool EmbCacheManagerImpl::CheckInitializer(uint32_t extEmbSize, std::vector initializerInfos) +{ + std::sort(initializerInfos.begin(), initializerInfos.end(), + [](const auto &u, const auto &v) { return u.start < v.start; }); + uint32_t cur_pos = 0; + for (const auto &info : initializerInfos) { + if (info.initializer == nullptr) { + ExternalLogger::PrintLog(LogLevel::ERROR, "initializer is nullptr"); + return false; + } + if (info.start != cur_pos) { + ExternalLogger::PrintLog(LogLevel::ERROR, "Initializers got coverage problems"); + return false; + } + cur_pos += info.len; + } + // 最后判断 + if (cur_pos != extEmbSize) { + ExternalLogger::PrintLog(LogLevel::ERROR, "Initializers got coverage problems"); + return false; + } + return true; +} + +bool EmbCacheManagerImpl::CheckValidThreadNum(uint32_t threadNum) +{ + uint32_t processCoreNum = std::thread::hardware_concurrency(); + if (threadNum > processCoreNum) { + ExternalLogger::PrintLog(LogLevel::ERROR, "ThreadNum can not larger than cpu core num"); + return false; + } + + if (threadNum == 0) { + ExternalLogger::PrintLog(LogLevel::ERROR, "ThreadNum can not be zero"); + return false; + } + return true; +} + +int EmbCacheManagerImpl::CheckGetSwapPairsAndKey2Offset(std::string tableName, const KeyOffsetPair &swapInKoPair, + const KeyOffsetPair &swapOutKoPair) +{ + if (!swapInKoPair.first.empty() || !swapInKoPair.second.empty() || !swapOutKoPair.first.empty() || + !swapOutKoPair.second.empty()) { + ExternalLogger::PrintLog(LogLevel::ERROR, "koPair should be empty"); + return H_ARG_NOT_EMPTY; + } + + int checkTableNameRet = CheckValidTableName(tableName); + if (checkTableNameRet != H_OK) { + return checkTableNameRet; + } + + return H_OK; +} + +int EmbCacheManagerImpl::CheckCreateTableName(const std::string &tableName) +{ + if (tableName.empty()) { + ExternalLogger::PrintLog(LogLevel::ERROR, "tableName can not be empty"); + return H_TABLE_NAME_EMPTY; + } + + if (tableName.size() > TABLE_NAME_MAX_SIZE) { + ExternalLogger::PrintLog(LogLevel::ERROR, + "tableName size can not larger than " + std::to_string(TABLE_NAME_MAX_SIZE)); + return H_TABLE_NAME_TOO_LONG; + } + return H_OK; +} + +uint32_t EmbCacheManagerImpl::GetUsage(const std::string &tableName) +{ + return offsetMappers[tableName].GetUsage(); +} diff --git a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.h b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.h new file mode 100644 index 00000000..314d0572 --- /dev/null +++ b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.h @@ -0,0 +1,86 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef EMBEDDING_CACHE_MANAGER_H +#define EMBEDDING_CACHE_MANAGER_H + +#include +#include +#include +#include + +#include "embedding_cache.h" +#include "error_code.h" +#include "offset_mapper/offset_mapper.h" +#include "embedding_local_table/emb_local_table.h" + +namespace EmbCache { +class EmbCacheManagerImpl : public EmbCacheManager { +public: + EmbCacheManagerImpl() = default; + + ~EmbCacheManagerImpl() override = default; + + int CreateCacheForTable(const EmbCacheInfo &embCacheInfo, const std::vector &initializerInfos, + int64_t invalidKey, uint64_t prefillBufferSize, uint32_t refillThreadNum) override; + + int GetSwapPairsAndKey2Offset(std::string tableName, std::vector &keys, KeyOffsetPair &swapInKoPair, + KeyOffsetPair &swapOutKoPair) override; + + int EmbeddingLookup(std::string tableName, const std::vector &keys, float *embAddr, + uint32_t threadNum) override; + + int EmbeddingLookupAddrs(std::string tableName, const std::vector &keys, std::vector &addrs, + uint32_t threadNum) override; + + int EmbeddingUpdate(std::string tableName, const std::vector &keys, float *embAddr, + uint32_t threadNum) override; + + int EmbeddingRemove(std::string tableName, const std::vector &keys, uint32_t threadNum) override; + + int EmbeddingLookupAndRemove(std::string tableName, const std::vector &keys, float *embAddr, + uint32_t threadNum) override; + + int RemoveEmbsByKeys(std::string tableName, const std::vector &keys) override; + + int GetEmbTableNames(std::vector &allTableNames) override; + + int ExportDeviceKeyOffsetPairs(std::string tableName, std::vector> &koVec) override; + + int Serialize(std::string tableName, std::vector &buffer) override; + + int Deserialize(std::string tableName, const std::vector &buffer) override; + + void Destroy() override; + + uint32_t GetUsage(const std::string &tableName) override; + +private: + std::map embCacheInfos; + std::map offsetMappers; + std::map embTables; + + int CheckValidTableName(std::string tableName); + + bool CheckInitializer(uint32_t extEmbSize, std::vector initializerInfos); + + bool CheckValidThreadNum(uint32_t threadNum); + + int CheckGetSwapPairsAndKey2Offset(std::string tableName, const KeyOffsetPair &swapInKoPair, + const KeyOffsetPair &swapOutKoPair); + + int CheckCreateTableName(const std::string &tableName); +}; +} +#endif // EMBEDDING_CACHE_MANAGER_H diff --git a/src/AccCTR/src/embedding_cache/common.h b/src/AccCTR/src/embedding_cache/common.h new file mode 100644 index 00000000..33a975c3 --- /dev/null +++ b/src/AccCTR/src/embedding_cache/common.h @@ -0,0 +1,159 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef MXREC_COMMON_H +#define MXREC_COMMON_H + +#include +#include +#include + +#ifndef HM_UNLIKELY +#define HM_UNLIKELY(x) __builtin_expect(!!(x), 0) +#endif + +#ifndef HM_LIKELY +#define HM_LIKELY(x) __builtin_expect(!!(x), 1) +#endif + +namespace EmbCache { + +class LimitedSet { +public: + struct Node { + uint64_t value; + Node *prev, *next; + Node(uint64_t val = -1) : value(val), prev(nullptr), next(nullptr) {} + }; + + LimitedSet(uint64_t maxRange) : head(new Node(-1)), tail(new Node(-1)) + { + nodes.resize(maxRange); + for (auto &node : nodes) { + node = new Node(-1); + } + head->next = tail; + tail->prev = head; + } + + ~LimitedSet() + { + for (auto &node : nodes) { + delete node; + } + delete head; + delete tail; + } + + void insert(uint64_t value) + { + if (nodes[value]->value == value) { + return; + } + Node *node = nodes[value]; + node->value = value; + Node *next = head->next; + node->next = next; + node->prev = head; + head->next = node; + next->prev = node; + } + + void remove(uint64_t value) + { + if (nodes[value]->value != value) { + return; + } + Node *node = nodes[value]; + node->prev->next = node->next; + node->next->prev = node->prev; + node->value = -1; + } + + bool find(uint64_t value) + { + return nodes[value]->value == value; + } + + class Iterator { + public: + Iterator(Node *node) : current(node) {} + bool operator != (const Iterator &other) const + { + return current != other.current; + } + const uint64_t &operator*() const + { + return current->value; + } + Iterator &operator ++ () + { + current = current->next; + return *this; + } + + private: + Node *current; + }; + + Iterator begin() + { + return { head->next }; + } + + Iterator end() + { + return { tail }; + } + +private: + Node *head; + Node *tail; + std::vector nodes; +}; + +enum class FkvState { + FKV_EXIST, + FKV_NOT_EXIST, + FKV_KEY_CONFLICT, + FKV_BEFORE_PUT_FUNC_FAIL, + FKV_BEFORE_REMOVE_FUNC_FAIL, + FKV_NO_SPACE, + FKV_FAIL, +}; + +enum class BeforePutFuncState { + BEFORE_SUCCESS, + BEFORE_NO_SPACE, + BEFORE_FAIL, +}; + +enum class BeforeRemoveFuncState { + BEFORE_SUCCESS, + BEFORE_FAIL, +}; + +extern int64_t INVALID_KEY; +constexpr uint64_t TABLE_NAME_MAX_SIZE = 1024; +const uint32_t VOCAB_CACHE_RATIO = 15; +constexpr float NORMAL_MEAN_MAX = 1e9; +constexpr float NORMAL_MEAN_MIN = -1e9; +constexpr float NORMAL_STDDEV_MAX = 100; +constexpr float NORMAL_STDDEV_MIN = 0; +constexpr float CONSTANT_VALUE_MAX = 1e9; +constexpr float CONSTANT_VALUE_MIN = -1e9; +constexpr float INIT_K_MAX = 10000; +constexpr float INIT_K_MIN = -10000; +} +#endif // MXREC_COMMON_H diff --git a/src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.cpp b/src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.cpp new file mode 100644 index 00000000..f0b050d6 --- /dev/null +++ b/src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.cpp @@ -0,0 +1,385 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#include "emb_local_table.h" +#include +#include +#include "securec.h" +#include "error_code.h" + +using namespace std; +using namespace EmbCache; +using namespace ock; +using namespace ock::ctr; + +bool EmbLocalTable::Initialize(uint32_t extEmbeddingSize, uint32_t hostVocabSize, uint32_t reserve, + const std::vector &initializerInfos, const EmbPoolParam &embPoolParam) +{ + emExpendMemInfo = make_shared(embPoolParam.prefillBufferSize, initializerInfos, + extEmbeddingSize, hostVocabSize, embPoolParam.refillThreadNum); + return embMap.Initialize(reserve, hostVocabSize, emExpendMemInfo); +} + +void EmbLocalTable::UnInitialize() +{ + embMap.UnInitialize(); +} + +int EmbLocalTable::FindAndPutIfNotFound(uint64_t key, uint64_t &value, bool init) +{ + FkvState ret = embMap.FindAndPutIfNotFound(key, value, init); + if (ret == FkvState::FKV_FAIL) { + return H_ERROR; + } + if (ret == FkvState::FKV_BEFORE_PUT_FUNC_FAIL) { + return H_MEMORY_ALLOC_ERROR; + } + if (ret == FkvState::FKV_NO_SPACE) { + return H_HOST_VOCAB_SIZE_TOO_SMALL; + } + return H_OK; +} + +bool EmbLocalTable::Remove(uint64_t key) +{ + return embMap.Remove(key) != FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; +} + +int EmbLocalTable::RemoveByKeys(const std::vector &keys, uint32_t threadNum) +{ + if (threadNum == 1) { + for (uint64_t key : keys) { + Remove(key); + } + return H_OK; + } + // 每个线程处理[start[threadId],start[threadId+1])这个区间的key + uint32_t m = keys.size() % threadNum; + vector start(threadNum + 1); + // 前keys.size()%threadNum个线程向上取整 + for (uint32_t threadId = 0; threadId < m; threadId++) { + start[threadId] = ((keys.size() + threadNum - 1) / threadNum) * threadId; + } + // 后面的向下取整 + for (uint32_t threadId = m; threadId <= threadNum; threadId++) { + start[threadId] = (keys.size() / threadNum) * threadId + m; + } + + vector threads(threadNum); + for (uint32_t threadId = 0; threadId < threadNum; threadId++) { + threads[threadId] = thread([&, threadId] { + for (uint64_t i = start[threadId]; i < start[threadId + 1]; i++) { + Remove(keys[i]); + } + }); + } + for (auto &t : threads) { + t.join(); + } + return H_OK; +} + +int EmbLocalTable::OneThreadHandle(uint64_t startAddr, const std::vector &keys, bool isGather) +{ + for (uint64_t i = 0; i < keys.size(); i++) { + uint64_t embAddr; + int ret = FindAndPutIfNotFound(keys[i], embAddr); + if (ret != H_OK) { + return ret; + } + uint64_t memSize = emExpendMemInfo->extEmbeddingSize * sizeof(float); + auto addr = startAddr + i * memSize; + if (isGather) { + auto rc = memcpy_s(reinterpret_cast(addr), memSize, reinterpret_cast(embAddr), memSize); + if (rc != 0) { + ExternalLogger::PrintLog(LogLevel::ERROR, + "gather memcpy_s failed... dstSize: " + std::to_string(memSize)); + return H_COPY_ERROR; + } + } else { + auto rc = memcpy_s(reinterpret_cast(embAddr), memSize, // 按顺序把新的embedding拷贝到对应地址中 + reinterpret_cast(addr), memSize); + if (rc != 0) { + ExternalLogger::PrintLog(LogLevel::ERROR, + "scatter memcpy_s failed... dstSize: " + std::to_string(memSize)); + return H_COPY_ERROR; + } + } + } + + return H_OK; +} + +int EmbLocalTable::Gather(uint64_t startAddr, const vector &keys, uint32_t threadNum) +{ + if (threadNum == 1) { + return OneThreadHandle(startAddr, keys, true); + } + + // 每个线程处理[start[threadId],start[threadId+1])这个区间的key + uint32_t m = keys.size() % threadNum; + vector start(threadNum + 1); + // 前keys.size()%threadNum个线程向上取整 + for (uint32_t threadId = 0; threadId < m; threadId++) { + start[threadId] = ((keys.size() + threadNum - 1) / threadNum) * threadId; + } + // 后面的向下取整 + for (uint32_t threadId = m; threadId <= threadNum; threadId++) { + start[threadId] = (keys.size() / threadNum) * threadId + m; + } + + vector threads(threadNum); + int ret = H_OK; + for (uint32_t threadId = 0; threadId < threadNum; threadId++) { + threads[threadId] = thread([&, threadId] { + for (uint64_t i = start[threadId]; i < start[threadId + 1]; i++) { + uint64_t embAddr; + int temp_ret = FindAndPutIfNotFound(keys[i], embAddr); + if (temp_ret != H_OK) { + ret = temp_ret; + return; + } + uint64_t memSize = emExpendMemInfo->extEmbeddingSize * sizeof(float); + auto addr = startAddr + i * memSize; + auto rc = memcpy_s(reinterpret_cast(addr), memSize, reinterpret_cast(embAddr), memSize); + if (rc != 0) { + ExternalLogger::PrintLog(LogLevel::ERROR, "memcpy_s failed... dstSize: " + std::to_string(memSize)); + ret = H_COPY_ERROR; + return; + } + } + }); + } + for (auto &t : threads) { + t.join(); + } + return ret; +} + +int EmbLocalTable::GatherAddrs(const std::vector &keys, std::vector &addrs, uint32_t threadNum) +{ + if (threadNum == 1) { + addrs.resize(keys.size()); + for (uint64_t i = 0; i < keys.size(); i++) { + int temp_ret = FindAndPutIfNotFound(keys[i], reinterpret_cast(addrs[i])); + if (temp_ret != H_OK) { + return temp_ret; + } + } + return H_OK; + } + // 每个线程处理[start[threadId],start[threadId+1])这个区间的key + uint32_t m = keys.size() % threadNum; + vector start(threadNum + 1); + // 前keys.size()%threadNum个线程向上取整 + for (uint32_t threadId = 0; threadId < m; threadId++) { + start[threadId] = ((keys.size() + threadNum - 1) / threadNum) * threadId; + } + // 后面的向下取整 + for (uint32_t threadId = m; threadId <= threadNum; threadId++) { + start[threadId] = (keys.size() / threadNum) * threadId + m; + } + addrs.resize(keys.size()); + + vector threads(threadNum); + int ret = H_OK; + for (uint32_t threadId = 0; threadId < threadNum; threadId++) { + threads[threadId] = thread([&, threadId] { + for (uint64_t i = start[threadId]; i < start[threadId + 1]; i++) { + int temp_ret = FindAndPutIfNotFound(keys[i], reinterpret_cast(addrs[i])); + if (temp_ret != H_OK) { + ret = temp_ret; + return; + } + } + }); + } + for (auto &t : threads) { + t.join(); + } + return ret; +} + +// 如果多线程使用,严格保证传入的key线程间不会重复(unique key),否则可能出现未定义结果 +int EmbLocalTable::GatherAndRemove(uint64_t startAddr, const vector &keys, uint32_t threadNum) +{ + if (threadNum == 1) { + for (uint64_t i = 0; i < keys.size(); i++) { + uint64_t memSize = emExpendMemInfo->extEmbeddingSize * sizeof(float); + auto addr = startAddr + i * memSize; + auto ret = embMap.FindAndRemoveIfFound(keys[i], addr); // 如果找到了就拷贝出来然后把key删了 + if (ret == FkvState::FKV_NOT_EXIST) { // 没找到key,给一个新的初始化值并且不需要存入key + auto *embAddr = reinterpret_cast(addr); + for (const auto &initializerInfo : emExpendMemInfo->initializerInfos) { + initializerInfo.initializer->GenerateData(embAddr); + } + } else if (ret == FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL) { + ExternalLogger::PrintLog(LogLevel::ERROR, "memcpy_s failed... dstSize: " + std::to_string(memSize)); + return H_COPY_ERROR; + } + } + return H_OK; + } + + // 每个线程处理[start[threadId],start[threadId+1])这个区间的key + uint32_t m = keys.size() % threadNum; + vector start(threadNum + 1); + // 前keys.size()%threadNum个线程向上取整 + for (uint32_t threadId = 0; threadId < m; threadId++) { + start[threadId] = ((keys.size() + threadNum - 1) / threadNum) * threadId; + } + // 后面的向下取整 + for (uint32_t threadId = m; threadId <= threadNum; threadId++) { + start[threadId] = (keys.size() / threadNum) * threadId + m; + } + + vector threads(threadNum); + int retVal = H_OK; + for (uint32_t threadId = 0; threadId < threadNum; threadId++) { + threads[threadId] = thread([&, threadId] { + for (uint64_t i = start[threadId]; i < start[threadId + 1]; i++) { + uint64_t memSize = emExpendMemInfo->extEmbeddingSize * sizeof(float); + auto addr = startAddr + i * memSize; + auto ret = embMap.FindAndRemoveIfFound(keys[i], addr); // 如果找到了就拷贝出来然后把key删了 + if (ret == FkvState::FKV_NOT_EXIST) { // 没找到key,给一个新的初始化值并且不需要存入key + auto *embAddr = reinterpret_cast(addr); + for (const auto &initializerInfo : emExpendMemInfo->initializerInfos) { + initializerInfo.initializer->GenerateData(embAddr); + } + } else if (ret == FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL) { + ExternalLogger::PrintLog(LogLevel::ERROR, "memcpy_s failed... dstSize: " + std::to_string(memSize)); + retVal = H_COPY_ERROR; + return; + } + } + }); + } + for (auto &t : threads) { + t.join(); + } + return retVal; +} + +int EmbLocalTable::Scatter(const uint64_t startAddr, const vector &keys, uint32_t threadNum) +{ + if (threadNum == 1) { // 单线程版本 + return OneThreadHandle(startAddr, keys, false); + } + + // 多线程版本 + // 每个线程处理[start[threadId],start[threadId+1])这个区间的key + uint32_t m = keys.size() % threadNum; + vector start(threadNum + 1); + // 前keys.size()%threadNum个线程向上取整 + for (uint32_t threadId = 0; threadId < m; threadId++) { + start[threadId] = ((keys.size() + threadNum - 1) / threadNum) * threadId; + } + // 后面的向下取整 + for (uint32_t threadId = m; threadId <= threadNum; threadId++) { + start[threadId] = (keys.size() / threadNum) * threadId + m; + } + + vector threads(threadNum); + int ret = H_OK; + for (uint32_t threadId = 0; threadId < threadNum; threadId++) { + threads[threadId] = thread([&, threadId] { + for (uint64_t i = start[threadId]; i < start[threadId + 1]; i++) { + uint64_t embAddr; + int temp_ret = FindAndPutIfNotFound(keys[i], embAddr); // 获取每个key的embedding对应首地址 + if (temp_ret != H_OK) { + ret = temp_ret; + return; + } + uint64_t memSize = emExpendMemInfo->extEmbeddingSize * sizeof(float); + auto addr = startAddr + i * memSize; + auto rc = memcpy_s(reinterpret_cast(embAddr), memSize, // 按顺序把新的embedding拷贝到对应地址中 + reinterpret_cast(addr), memSize); + if (rc != 0) { + ExternalLogger::PrintLog(LogLevel::ERROR, "memcpy_s failed... dstSize: " + std::to_string(memSize)); + ret = H_COPY_ERROR; + return; + } + } + }); + } + for (auto &t : threads) { + t.join(); + } + return ret; +} + +// 导出存储的所有kv对 +vector> EmbLocalTable::ExportVec() +{ + return embMap.ExportVec(); +} + +template void EmbLocalTable::insertData(vector &buffer, T &data) +{ + buffer.insert(buffer.end(), (char *)&data, (char *)&data + sizeof(data)); +} + +template bool EmbLocalTable::getData(const vector &buffer, T &data, uint64_t &i) +{ + if (i + sizeof(T) > buffer.size()) { + return false; + } + data = *reinterpret_cast(&buffer[i]); + i += sizeof(T); + return true; +} + +// 把所存储的key-embedding信息序列化 +vector EmbLocalTable::Serialize() +{ + vector buffer; + vector> kvVec = ExportVec(); + + for (auto &p : kvVec) { + uint64_t key = p.first; + uint64_t value = p.second; + insertData(buffer, key); + auto *addr = reinterpret_cast(value); + buffer.insert(buffer.end(), reinterpret_cast(addr), + reinterpret_cast((addr + emExpendMemInfo->extEmbeddingSize))); + } + return buffer; +} + +// 反序列化key-embedding,存进map +bool EmbLocalTable::Deserialize(const vector &buffer) +{ + uint64_t i = 0; + while (i < buffer.size()) { + uint64_t key; + if (!getData(buffer, key, i)) { + ExternalLogger::PrintLog(LogLevel::ERROR, "get data failed!"); + return false; + } + uint64_t value = 0; + if (FindAndPutIfNotFound(key, value, false) != H_OK) { + ExternalLogger::PrintLog(LogLevel::ERROR, "FindAndPutIfNotFound failed!"); + return false; + } + + auto *addr = reinterpret_cast(value); + for (uint32_t j = 0; j < emExpendMemInfo->extEmbeddingSize; j++) { + if (!getData(buffer, addr[j], i)) { + ExternalLogger::PrintLog(LogLevel::ERROR, "get data failed!"); + return false; + } + } + } + return true; +} \ No newline at end of file diff --git a/src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.h b/src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.h new file mode 100644 index 00000000..eda92698 --- /dev/null +++ b/src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.h @@ -0,0 +1,69 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef EMB_LOCAL_TABLE_H +#define EMB_LOCAL_TABLE_H + +#include +#include +#include +#include "offset_mapper/address_mapper.h" + +namespace EmbCache { +struct EmbPoolParam { + uint64_t prefillBufferSize; + uint32_t refillThreadNum; +}; + +class EmbLocalTable { +public: + EmbLocalTable() = default; + + ~EmbLocalTable() = default; + + bool Initialize(uint32_t extEmbeddingSize, uint32_t hostVocabSize, uint32_t reserve, + const std::vector &initializerInfos, const EmbPoolParam &embPoolParam); + + void UnInitialize(); + + int FindAndPutIfNotFound(uint64_t key, uint64_t &value, bool init = true); + + bool Remove(uint64_t key); + + int RemoveByKeys(const std::vector &keys, uint32_t threadNum); + + int Gather(uint64_t startAddr, const std::vector &keys, uint32_t threadNum); + + int GatherAddrs(const std::vector &keys, std::vector &addrs, uint32_t threadNum); + + int Scatter(uint64_t startAddr, const std::vector &keys, uint32_t threadNum); + + int OneThreadHandle(uint64_t startAddr, const std::vector &keys, bool isGather); + + int GatherAndRemove(uint64_t startAddr, const std::vector &keys, uint32_t threadNum); + + std::vector> ExportVec(); + + std::vector Serialize(); + + bool Deserialize(const std::vector &buffer); + +private: + std::shared_ptr emExpendMemInfo; + AddressMapper embMap; + template void insertData(std::vector &buffer, T &data); + template bool getData(const std::vector &buffer, T &data, uint64_t &i); +}; +} +#endif // EMB_LOCAL_TABLE_H diff --git a/src/AccCTR/src/embedding_cache/initializer/constant_initializer/constant_initializer.cpp b/src/AccCTR/src/embedding_cache/initializer/constant_initializer/constant_initializer.cpp new file mode 100644 index 00000000..2cd9f267 --- /dev/null +++ b/src/AccCTR/src/embedding_cache/initializer/constant_initializer/constant_initializer.cpp @@ -0,0 +1,56 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#include "embedding_cache.h" +#include "embedding_cache/common.h" +#include "external_logger.h" + +using namespace std; +using namespace EmbCache; +using namespace ock; + +ConstantInitializer::ConstantInitializer(uint32_t start, uint32_t len, float value, float initK) + : start(start), len(len) +{ + if (value > CONSTANT_VALUE_MAX) { + ExternalLogger::PrintLog(LogLevel::WARN, "constant value is greater than " + + std::to_string(CONSTANT_VALUE_MAX) + ", and will use " + std::to_string(CONSTANT_VALUE_MAX) + "."); + constantValue = CONSTANT_VALUE_MAX; + } else if (value < CONSTANT_VALUE_MIN) { + ExternalLogger::PrintLog(LogLevel::WARN, "constant value is less than " + std::to_string(CONSTANT_VALUE_MIN) + + ", and will use " + std::to_string(CONSTANT_VALUE_MIN) + "."); + constantValue = CONSTANT_VALUE_MIN; + } else { + constantValue = value; + } + if (initK > INIT_K_MAX) { + ExternalLogger::PrintLog(LogLevel::WARN, "constant initK is greater than " + std::to_string(INIT_K_MAX) + + ", and will use " + std::to_string(INIT_K_MAX) + "."); + initParam = INIT_K_MAX; + } else if (initK < INIT_K_MIN) { + ExternalLogger::PrintLog(LogLevel::WARN, "constant initK is less than " + std::to_string(INIT_K_MIN) + + ", and will use " + std::to_string(INIT_K_MIN) + "."); + initParam = INIT_K_MIN; + } else { + initParam = initK; + } +} + +void ConstantInitializer::GenerateData(float * const emb) +{ + if (len == 0) { + return; + } + std::fill_n(emb + start, len, initParam * constantValue); +} diff --git a/src/AccCTR/src/embedding_cache/initializer/initializer.cpp b/src/AccCTR/src/embedding_cache/initializer/initializer.cpp new file mode 100644 index 00000000..887aaee0 --- /dev/null +++ b/src/AccCTR/src/embedding_cache/initializer/initializer.cpp @@ -0,0 +1,56 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#include + +#include "external_logger.h" +#include "embedding_cache.h" + +using namespace EmbCache; + +ConstantInitializerInfo::ConstantInitializerInfo(float constantValue, float initK) + : constantValue(constantValue), initK(initK) +{} + +NormalInitializerInfo::NormalInitializerInfo(float mean, float stddev, uint32_t seed, float initK) + : mean(mean), stddev(stddev), seed(seed), initK(initK) +{} + +InitializerInfo::InitializerInfo(std::string &name, uint32_t start, uint32_t len, + ConstantInitializerInfo constantInitializerInfo) + : name(name), start(start), len(len), constantInitializerInfo(constantInitializerInfo) +{ + if (name == "constant_initializer") { + initializerType = InitializerType::CONSTANT; + initializer = std::make_shared(start, len, constantInitializerInfo.constantValue, + constantInitializerInfo.initK); + } else { + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "Invalid Initializer Type."); + } +} + +InitializerInfo::InitializerInfo(std::string &name, uint32_t start, uint32_t len, + NormalInitializerInfo normalInitializerInfo) + : name(name), start(start), len(len), normalInitializerInfo(normalInitializerInfo) +{ + if (name == "truncated_normal_initializer") { + initializerType = InitializerType::TRUNCATED_NORMAL; + initializer = std::make_shared(start, len, normalInitializerInfo); + } else if (name == "random_normal_initializer") { + initializerType = InitializerType::RANDOM_NORMAL; + initializer = std::make_shared(start, len, normalInitializerInfo); + } else { + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "Invalid Initializer Type."); + } +} diff --git a/src/AccCTR/src/embedding_cache/initializer/random_normal_initializer/random_normal_initializer.cpp b/src/AccCTR/src/embedding_cache/initializer/random_normal_initializer/random_normal_initializer.cpp new file mode 100644 index 00000000..979fd212 --- /dev/null +++ b/src/AccCTR/src/embedding_cache/initializer/random_normal_initializer/random_normal_initializer.cpp @@ -0,0 +1,72 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#include +#include +#include "embedding_cache.h" +#include "embedding_cache/common.h" +#include "external_logger.h" + +using namespace EmbCache; +using namespace ock; + +RandomNormalInitializer::RandomNormalInitializer(uint32_t start, uint32_t len, NormalInitializerInfo &initInfo) + : start(start), len(len), mean(initInfo.mean), stddev(initInfo.stddev), seed(initInfo.seed) +{ + // 校验stddev mean及initK值范围 + if (initInfo.mean > NORMAL_MEAN_MAX) { + ExternalLogger::PrintLog(LogLevel::WARN, "random normal mean param is greater than " + + std::to_string(NORMAL_MEAN_MAX) + ", and will use " + std::to_string(NORMAL_MEAN_MAX) + "."); + mean = NORMAL_MEAN_MAX; + } else if (initInfo.mean < NORMAL_MEAN_MIN) { + ExternalLogger::PrintLog(LogLevel::WARN, "random normal mean param is less than " + + std::to_string(NORMAL_MEAN_MIN) + ", and will use " + std::to_string(NORMAL_MEAN_MIN) + "."); + mean = NORMAL_MEAN_MIN; + } else { + mean = initInfo.mean; + } + if (initInfo.stddev > NORMAL_STDDEV_MAX) { + ExternalLogger::PrintLog(LogLevel::WARN, "random normal stddev param is greater than " + + std::to_string(NORMAL_STDDEV_MAX) + ", and will use " + std::to_string(NORMAL_STDDEV_MAX) + "."); + stddev = NORMAL_STDDEV_MAX; + } else if (initInfo.stddev < NORMAL_STDDEV_MIN) { + ExternalLogger::PrintLog(LogLevel::WARN, "random normal stddev param is less than " + + std::to_string(NORMAL_STDDEV_MIN) + ", and will use " + std::to_string(NORMAL_STDDEV_MIN) + "."); + stddev = NORMAL_STDDEV_MIN; + } else { + stddev = initInfo.stddev; + } + if (initInfo.initK > INIT_K_MAX) { + ExternalLogger::PrintLog(LogLevel::WARN, "random normal initK is greater than " + std::to_string(INIT_K_MAX) + + ", and will use " + std::to_string(INIT_K_MAX) + "."); + initParam = INIT_K_MAX; + } else if (initInfo.initK < INIT_K_MIN) { + ExternalLogger::PrintLog(LogLevel::WARN, "random normal initK is less than " + std::to_string(INIT_K_MIN) + + ", and will use " + std::to_string(INIT_K_MIN) + "."); + initParam = INIT_K_MIN; + } else { + initParam = initInfo.initK; + } + + generator = std::default_random_engine(seed); + distribution = std::normal_distribution(mean, stddev); +} + +void RandomNormalInitializer::GenerateData(float * const emb) +{ + if (len == 0) { + return; + } + std::generate_n(emb + start, len, [this]() { return initParam * distribution(generator); }); +} \ No newline at end of file diff --git a/src/AccCTR/src/embedding_cache/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp b/src/AccCTR/src/embedding_cache/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp new file mode 100644 index 00000000..c2441466 --- /dev/null +++ b/src/AccCTR/src/embedding_cache/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp @@ -0,0 +1,80 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#include +#include "embedding_cache.h" +#include "embedding_cache/common.h" +#include "external_logger.h" + +using namespace EmbCache; +using namespace ock; + +TruncatedNormalInitializer::TruncatedNormalInitializer(uint32_t start, uint32_t len, NormalInitializerInfo &initInfo) + : start(start), len(len), mean(initInfo.mean), stddev(initInfo.stddev), seed(initInfo.seed) +{ + // 校验stddev mean及initK值范围 + if (initInfo.mean > NORMAL_MEAN_MAX) { + ExternalLogger::PrintLog(LogLevel::WARN, "truncated normal mean param is greater than " + + std::to_string(NORMAL_MEAN_MAX) + ", and will use " + std::to_string(NORMAL_MEAN_MAX) + "."); + mean = NORMAL_MEAN_MAX; + } else if (initInfo.mean < NORMAL_MEAN_MIN) { + ExternalLogger::PrintLog(LogLevel::WARN, "truncated normal mean param is less than " + + std::to_string(NORMAL_MEAN_MIN) + ", and will use " + std::to_string(NORMAL_MEAN_MIN) + "."); + mean = NORMAL_MEAN_MIN; + } else { + mean = initInfo.mean; + } + if (initInfo.stddev > NORMAL_STDDEV_MAX) { + ExternalLogger::PrintLog(LogLevel::WARN, "truncated normal stddev param is greater than " + + std::to_string(NORMAL_STDDEV_MAX) + ", and will use " + std::to_string(NORMAL_STDDEV_MAX) + "."); + stddev = NORMAL_STDDEV_MAX; + } else if (initInfo.stddev < NORMAL_STDDEV_MIN) { + ExternalLogger::PrintLog(LogLevel::WARN, "truncated normal stddev param is less than " + + std::to_string(NORMAL_STDDEV_MIN) + ", and will use " + std::to_string(NORMAL_STDDEV_MIN) + "."); + stddev = NORMAL_STDDEV_MIN; + } else { + stddev = initInfo.stddev; + } + if (initInfo.initK > INIT_K_MAX) { + ExternalLogger::PrintLog(LogLevel::WARN, "truncated normal initK is greater than " + + std::to_string(INIT_K_MAX) + ", and will use " + std::to_string(INIT_K_MAX) + "."); + initParam = INIT_K_MAX; + } else if (initInfo.initK < INIT_K_MIN) { + ExternalLogger::PrintLog(LogLevel::WARN, "truncated normal initK is less than " + std::to_string(INIT_K_MIN) + + ", and will use " + std::to_string(INIT_K_MIN) + "."); + initParam = INIT_K_MIN; + } else { + initParam = initInfo.initK; + } + + generator = std::default_random_engine(seed); + distribution = std::normal_distribution(mean, stddev); + minBound = initParam * (mean - static_cast(boundNum) * stddev); + maxBound = initParam * (mean + static_cast(boundNum) * stddev); +} + + +void TruncatedNormalInitializer::GenerateData(float * const emb) +{ + if (len == 0) { + return; + } + std::generate_n(emb + start, len, [this]() { + float tmp = initParam * distribution(generator); + while (tmp < minBound || tmp > maxBound) { + tmp = initParam * distribution(generator); + } + return tmp; + }); +} diff --git a/src/AccCTR/src/embedding_cache/offset_mapper/address_mapper.h b/src/AccCTR/src/embedding_cache/offset_mapper/address_mapper.h new file mode 100644 index 00000000..8b6eefae --- /dev/null +++ b/src/AccCTR/src/embedding_cache/offset_mapper/address_mapper.h @@ -0,0 +1,311 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef MXREC_FASTER_QUERY_H +#define MXREC_FASTER_QUERY_H + +#include +#include +#include +#include +#include +#include +#include + +#include "securec.h" +#include "embedding_cache.h" +#include "offset_mapper/mapper_base.h" + +namespace EmbCache { +using EmExpandMemUint = struct em_expand_memory_uint_ { + uint64_t address = 0; + uint64_t capacity = 0; + uint64_t leftCapacity = 0; + + em_expand_memory_uint_() = default; + + em_expand_memory_uint_(uint64_t a, uint64_t c) : address(a), capacity(c), leftCapacity(c) {} +}; + +template class QWithLock { +public: + bool pop(T &ele) + { + std::lock_guard lk(mut); + if (dataQ.empty()) { + return false; + } + ele = dataQ.front(); + dataQ.pop(); + return true; + } + + void push(const T &ele) + { + std::lock_guard lk(mut); + dataQ.push(ele); + } + + uint64_t GetLength() + { + std::lock_guard lk(mut); + return dataQ.size(); + } + +private: + std::mutex mut; + std::queue dataQ; +}; + +class AutoRefillEmbeddingMemoryPool { +public: + std::vector expandedMemory; + uint32_t extEmbeddingSize; + std::vector initializerInfos; + + AutoRefillEmbeddingMemoryPool(uint64_t bufferSize, std::vector initInfos, uint32_t extEmbSize, + uint64_t hostVocabSize, uint32_t refillThreadNum = 1) + : extEmbeddingSize(extEmbSize), + initializerInfos(std::move(initInfos)), + maxBufferSize(bufferSize), + totalLeftVocabSize(hostVocabSize), + numThreads(refillThreadNum) + { + itemSize = extEmbeddingSize * sizeof(float); + maxExpandSize = maxBufferSize * itemSize; + for (uint32_t i = 0; i < numThreads; i++) { + producerThreads.emplace_back([this] { ProducerWorker(); }); + } + } + + ~AutoRefillEmbeddingMemoryPool() + { + { + std::lock_guard lock(producerMutex); + stop = true; + } + producerCv.notify_all(); + fullCv.notify_all(); + for (auto &t : producerThreads) { + t.join(); + } + } + + void Stop() + { + std::lock_guard lock(producerMutex); + stop = true; + producerCv.notify_all(); + fullCv.notify_all(); + } + + BeforePutFuncState GetNewValueToBeInserted(uint64_t &value, bool init = true, uint32_t maxRetry = 1000) + { + if (init) { + for (uint32_t i = 0; i < maxRetry; i++) { + if (BufferBin.pop(value)) { + producerCv.notify_one(); + return BeforePutFuncState::BEFORE_SUCCESS; + }; + producerCv.notify_one(); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, + "Failed to get new address for embedding, it is likely due to refill thread memory allocation failure " + "or max retry has been reached. Please check for memory alloc error or increase refill thread num!"); + return BeforePutFuncState::BEFORE_FAIL; + } + + if (!recycleBin.pop(value)) { + if (!GetNewAddr(value)) { + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "Failed to get new address for embedding, " + "memory allocation failure!"); + return BeforePutFuncState::BEFORE_FAIL; + } + } + + return BeforePutFuncState::BEFORE_SUCCESS; + } + + void GetValueToBeRecycled(uint64_t value) + { + std::lock_guard lock(producerMutex); + recycleBin.push(value); + full = false; + fullCv.notify_one(); + } + +private: + uint64_t maxBufferSize; + uint64_t totalLeftVocabSize; + uint32_t numThreads; + std::atomic currBufferSize{ 0 }; + volatile bool stop = false; + volatile std::atomic full = false; + std::mutex producerMutex; + std::mutex getAddrMutex; + std::condition_variable producerCv; + std::condition_variable fullCv; + QWithLock BufferBin; + QWithLock recycleBin; + std::vector producerThreads; + EmExpandMemUint currentMemoryUint{}; + uint64_t dynamicExpandRatio = 2; + uint64_t maxExpandSize; + uint64_t itemSize; + + bool GetNewAddr(uint64_t &newAddr) + { + std::lock_guard lg(getAddrMutex); + if (HM_UNLIKELY(currentMemoryUint.leftCapacity <= 0)) { + /* need to expand memory */ + uint64_t maxSize = std::min(maxExpandSize, totalLeftVocabSize * itemSize); + uint64_t newSize = currentMemoryUint.capacity ? + std::min(currentMemoryUint.capacity * dynamicExpandRatio, maxSize) : + itemSize; + if (newSize == 0) { + if (recycleBin.GetLength() == 0) { + full = true; + } + return false; + } + auto newAddress = (uint64_t)malloc(newSize); + if (newAddress == 0) { + ock::ExternalLogger::PrintLog(ock::LogLevel::WARN, "Refill thread allocate memory failed!"); + return false; + } + expandedMemory.emplace_back(newAddress, newSize); + currentMemoryUint.address = newAddress; + currentMemoryUint.capacity = newSize; + currentMemoryUint.leftCapacity = newSize; + totalLeftVocabSize -= newSize / itemSize; + } + newAddr = currentMemoryUint.address + currentMemoryUint.capacity - currentMemoryUint.leftCapacity; + currentMemoryUint.leftCapacity -= itemSize; + return true; + } + + void Produce() + { + uint64_t newAddr; + if (!recycleBin.pop(newAddr)) { + if (!GetNewAddr(newAddr)) { + return; + } + } + GenerateData(newAddr); + BufferBin.push(newAddr); + } + + void GenerateData(const uint64_t &addr) + { + auto *embAddr = reinterpret_cast(addr); + for (const auto &initializerInfo : initializerInfos) { + initializerInfo.initializer->GenerateData(embAddr); + } + } + + void ProducerWorker() + { + std::unique_lock lock(producerMutex); + while (!stop) { + if (BufferBin.GetLength() < maxBufferSize && !full) { + Produce(); + } else if (!full) { + producerCv.wait(lock); + } else { + fullCv.wait(lock); + } + } + } +}; + +class AddressMapper : public MapperBase { +public: + AddressMapper() = default; + + ~AddressMapper() = default; + + bool Initialize(uint32_t reserve, uint32_t vocabSize, std::shared_ptr expendInfoPtr) + { + hostVocabSize = vocabSize; + emExpendMemInfoPtr = expendInfoPtr; + return MapperBase::Initialize(reserve); + } + + void UnInitialize() override + { + emExpendMemInfoPtr->Stop(); + FreeExpandedMemory(); + MapperBase::UnInitialize(); + } + + FkvState Remove(uint64_t key) + { + return MapperBase::Remove(key, [&](uint64_t value) { + emExpendMemInfoPtr->GetValueToBeRecycled(value); + return BeforeRemoveFuncState::BEFORE_SUCCESS; + }); + } + + FkvState FindAndPutIfNotFound(uint64_t key, uint64_t &value, bool init = true) + { + FkvState ret = MapperBase::FindAndPutIfNotFound(key, value, [&]() { + if (HM_UNLIKELY(current_size.load() >= hostVocabSize)) { + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "host does not have enough space"); + return BeforePutFuncState::BEFORE_NO_SPACE; + } + return emExpendMemInfoPtr->GetNewValueToBeInserted(value, init); + }); + if (ret == FkvState::FKV_FAIL) { + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "FindAndPutIfNotFound failed!"); + return ret; + } + if (ret == FkvState::FKV_BEFORE_PUT_FUNC_FAIL) { + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "malloc failed"); + return ret; + } + return ret; + } + + // 如果多线程使用,严格保证传入的key线程间不会重复(unique key),否则可能出现未定义结果 + FkvState FindAndRemoveIfFound(uint64_t key, const uint64_t startAddr) + { + return MapperBase::Remove(key, [&](uint64_t value) { + uint64_t memSize = emExpendMemInfoPtr->extEmbeddingSize * sizeof(float); + auto rc = memcpy_s(reinterpret_cast(startAddr), memSize, reinterpret_cast(value), memSize); + if (rc != 0) { + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, + "memcpy_s failed... dstSize: " + std::to_string(memSize)); + return BeforeRemoveFuncState::BEFORE_FAIL; + } + emExpendMemInfoPtr->GetValueToBeRecycled(value); + return BeforeRemoveFuncState::BEFORE_SUCCESS; + }); + } + +private: + void FreeExpandedMemory() + { + for (auto &memUint : emExpendMemInfoPtr->expandedMemory) { + free(reinterpret_cast(memUint.address)); + } + } + +private: + uint32_t hostVocabSize; + std::shared_ptr emExpendMemInfoPtr; +}; +} +#endif // MXREC_FASTER_QUERY_H diff --git a/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h b/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h new file mode 100644 index 00000000..363c59ee --- /dev/null +++ b/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h @@ -0,0 +1,785 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef MXREC_MAPPER_BASE_H +#define MXREC_MAPPER_BASE_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "securec.h" +#include "embedding_cache/common.h" +#include "external_logger.h" + +namespace EmbCache { +/* + * @brief Allocator template, for extend memory allocation for overflowed buckets + */ + +static constexpr size_t K_ALIGNMENT = 64; +static constexpr size_t K_KVNUMINBUCKET = 3; + +class NetHeapAllocator { +public: + void *Allocate(uint32_t size) + { + return calloc(1, size); + } + + void Free(void *p) + { + if (HM_LIKELY(p != nullptr)) { + free(p); + p = nullptr; + } + } +}; + +/* + * @brief Spin lock entry in bucket + * used for alloc overflowed buckets + */ + +struct NetHashLockEntry { + uint64_t lock = 0; + + /* + * @brief Spin lock + */ + void Lock() + { + while (!__sync_bool_compare_and_swap(&lock, 0, 1)) { + } + } + + /* + * @brief Unlock + */ + void UnLock() + { + __atomic_store_n(&lock, 0, __ATOMIC_SEQ_CST); + } +} __attribute__((packed)); + +/* + * @brief Store the key/value into a linked array with 6 items, + * because 64bytes is one cache line + */ + +struct alignas(K_ALIGNMENT)NetHashBucket { + std::atomic keys[K_KVNUMINBUCKET]{}; + uint64_t values[K_KVNUMINBUCKET]{}; + NetHashBucket *next = nullptr; + NetHashLockEntry spinLock{}; + + FkvState Put(uint64_t key, uint64_t &value, const std::function &beforePutFunc) + { + /* don't put them into loop, flat code is faster than loop */ + uint64_t oldKey = 0; + if (keys[0].load(std::memory_order_relaxed) == 0 && keys[0].compare_exchange_strong(oldKey, key)) { + BeforePutFuncState ret = beforePutFunc(); + if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_FAIL)) { + keys[0] = 0; + return FkvState::FKV_BEFORE_PUT_FUNC_FAIL; + } + if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_NO_SPACE)) { + keys[0] = 0; + return FkvState::FKV_NO_SPACE; + } + values[0] = value; + return FkvState::FKV_NOT_EXIST; + } + + if (HM_UNLIKELY(oldKey == key)) { + return FkvState::FKV_KEY_CONFLICT; + } + + oldKey = 0; + if (keys[1].load(std::memory_order_relaxed) == 0 && keys[1].compare_exchange_strong(oldKey, key)) { + BeforePutFuncState ret = beforePutFunc(); + if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_FAIL)) { + keys[1] = 0; + return FkvState::FKV_BEFORE_PUT_FUNC_FAIL; + } + if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_NO_SPACE)) { + keys[1] = 0; + return FkvState::FKV_NO_SPACE; + } + values[1] = value; + return FkvState::FKV_NOT_EXIST; + } + + if (HM_UNLIKELY(oldKey == key)) { + return FkvState::FKV_KEY_CONFLICT; + } + + oldKey = 0; + if (keys[2].load(std::memory_order_relaxed) == 0 && keys[2].compare_exchange_strong(oldKey, key)) { + BeforePutFuncState ret = beforePutFunc(); + if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_FAIL)) { + keys[2] = 0; + return FkvState::FKV_BEFORE_PUT_FUNC_FAIL; + } + if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_NO_SPACE)) { + keys[2] = 0; + return FkvState::FKV_NO_SPACE; + } + values[2] = value; + return FkvState::FKV_NOT_EXIST; + } + + if (HM_UNLIKELY(oldKey == key)) { + return FkvState::FKV_KEY_CONFLICT; + } + + return FkvState::FKV_FAIL; + } + + /* + * @brief Remove the address from the bucket and get size + */ + bool Find(const uint64_t key, uint64_t &value) + { + /* + * expand the loop, instead of put them into a for/while loop for performance + */ + if (key == keys[0].load(std::memory_order_relaxed)) { + value = values[0]; + return true; + } + + if (key == keys[1].load(std::memory_order_relaxed)) { + value = values[1]; + return true; + } + + if (key == keys[2].load(std::memory_order_relaxed)) { + value = values[2]; + return true; + } + + return false; + } + + FkvState Remove(uint64_t key) + { + /* don't put them into loop, flat code is faster than loop */ + uint64_t oldValue = key; + if (keys[0].load(std::memory_order_relaxed) == key && keys[0].compare_exchange_strong(oldValue, 0)) { + values[0] = 0; + return FkvState::FKV_EXIST; + } + if (HM_UNLIKELY(oldValue == 0)) { + return FkvState::FKV_EXIST; + } + oldValue = key; + + if (keys[1].load(std::memory_order_relaxed) == key && keys[1].compare_exchange_strong(oldValue, 0)) { + values[1] = 0; + return FkvState::FKV_EXIST; + } + if (HM_UNLIKELY(oldValue == 0)) { + return FkvState::FKV_EXIST; + } + oldValue = key; + + if (keys[2].load(std::memory_order_relaxed) == key && keys[2].compare_exchange_strong(oldValue, 0)) { + values[2] = 0; + return FkvState::FKV_EXIST; + } + if (HM_UNLIKELY(oldValue == 0)) { + return FkvState::FKV_EXIST; + } + + return FkvState::FKV_NOT_EXIST; + } + + FkvState Remove(uint64_t key, const std::function &beforeRemoveFunc) + { + /* don't put them into loop, flat code is faster than loop */ + uint64_t oldValue = key; + if (keys[0].load(std::memory_order_relaxed) == key && keys[0].compare_exchange_strong(oldValue, 0)) { + if (HM_UNLIKELY(beforeRemoveFunc(values[0]) == BeforeRemoveFuncState::BEFORE_FAIL)) { + return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; + } + + values[0] = 0; + return FkvState::FKV_EXIST; + } + if (HM_UNLIKELY(oldValue == 0)) { + return FkvState::FKV_EXIST; + } + oldValue = key; + + if (keys[1].load(std::memory_order_relaxed) == key && keys[1].compare_exchange_strong(oldValue, 0)) { + if (HM_UNLIKELY(beforeRemoveFunc(values[1]) == BeforeRemoveFuncState::BEFORE_FAIL)) { + return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; + } + + values[1] = 0; + return FkvState::FKV_EXIST; + } + if (HM_UNLIKELY(oldValue == 0)) { + return FkvState::FKV_EXIST; + } + oldValue = key; + + if (keys[2].load(std::memory_order_relaxed) == key && keys[2].compare_exchange_strong(oldValue, 0)) { + if (HM_UNLIKELY(beforeRemoveFunc(values[2]) == BeforeRemoveFuncState::BEFORE_FAIL)) { + return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; + } + + values[2] = 0; + return FkvState::FKV_EXIST; + } + if (HM_UNLIKELY(oldValue == 0)) { + return FkvState::FKV_EXIST; + } + + return FkvState::FKV_NOT_EXIST; + } +}; + + +class MapperBase { +public: + // DEFINE_RDMA_REF_COUNT_FUNCTIONS + std::atomic current_size{ 0 }; + + MapperBase() = default; + + ~MapperBase() = default; + + bool Initialize(uint32_t reserve) + { + /* already initialized */ + if (mOverflowEntryAlloc != nullptr) { + return true; + } + + /* get proper bucket count */ + uint32_t bucketCount = reserve < 128 ? 128 : reserve; + if (bucketCount > gPrimes[gPrimesCount - 1]) { + bucketCount = gPrimes[gPrimesCount - 1]; + } else { + uint32_t i = 0; + while (i < gPrimesCount && gPrimes[i] < bucketCount) { + i++; + } + bucketCount = gPrimes[i]; + } + + /* allocate buckets for sub-maps */ + for (auto &mSubMap : mSubMaps) { + auto tmp = new (std::nothrow) NetHashBucket[bucketCount]; + if (HM_UNLIKELY(tmp == nullptr)) { + FreeSubMaps(); + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, + "Failed to new hash bucket, probably out of memory"); + return false; + } + + /* make physical page and set to zero */ + auto ret = memset_s(tmp, sizeof(NetHashBucket) * bucketCount, 0, sizeof(NetHashBucket) * bucketCount); + if (ret != 0) { + ock::ExternalLogger::ExternalLogger::PrintLog(ock::LogLevel::ERROR, + "memset_s failed... size: " + std::to_string(sizeof(NetHashBucket) * bucketCount)); + return false; + } + + mSubMap = tmp; + } + + /* create overflow entry allocator */ + mOverflowEntryAlloc = new (std::nothrow) NetHeapAllocator(); + if (HM_UNLIKELY(mOverflowEntryAlloc == nullptr)) { + FreeSubMaps(); + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, + "Failed to new overflow entry allocator, probably out of memory"); + return false; + } + + /* set bucket count */ + mBucketCount = bucketCount; + ock::ExternalLogger::PrintLog(ock::LogLevel::INFO, + "fastKV inited, mBucketCount: " + std::to_string(mBucketCount)); + return true; + } + + virtual void UnInitialize() + { + if (mOverflowEntryAlloc == nullptr) { + return; + } + + /* free overflowed entries firstly */ + FreeOverFlowedEntries(); + + /* free sub map secondly */ + FreeSubMaps(); + + /* free overflow entry at last */ + delete mOverflowEntryAlloc; + mOverflowEntryAlloc = nullptr; + mBucketCount = 0; + } + + FkvState FindAndPutIfNotFound(uint64_t key, uint64_t &value, + const std::function &beforePutFunc) + { + if (HM_UNLIKELY(key == 0)) { + if (zeroInside) { + value = zeroValue; + return FkvState::FKV_EXIST; + } + if (__sync_bool_compare_and_swap(&zeroInside, false, true)) { + BeforePutFuncState ret = beforePutFunc(); + if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_FAIL)) { + return FkvState::FKV_BEFORE_PUT_FUNC_FAIL; + } + if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_NO_SPACE)) { + return FkvState::FKV_NO_SPACE; + } + zeroValue = value; + current_size++; + return FkvState::FKV_NOT_EXIST; + } + return FkvState::FKV_KEY_CONFLICT; + } + + /* get bucket */ + auto buck = &(mSubMaps[key % gSubMapCount][key % mBucketCount]); + + /* loop all buckets linked */ + while (buck != nullptr) { + buck->spinLock.Lock(); + if (buck->Find(key, value)) { + buck->spinLock.UnLock(); + return FkvState::FKV_EXIST; + } + buck->spinLock.UnLock(); + + if (buck->next != nullptr) { + buck = buck->next; + } else { + break; + } + } + + // did not find, now do put. continue from the last bucket in find + + /* try 8192 times */ + for (uint16_t i = 0; i < 8192; i++) { + /* loop all buckets linked */ + while (buck != nullptr) { + /* if there is an entry to put, just break */ + buck->spinLock.Lock(); + FkvState putRet = buck->Put(key, value, beforePutFunc); + buck->spinLock.UnLock(); + if (putRet == FkvState::FKV_NOT_EXIST) { + current_size++; + return FkvState::FKV_NOT_EXIST; + } + + if (HM_UNLIKELY(putRet == FkvState::FKV_KEY_CONFLICT)) { + return FkvState::FKV_KEY_CONFLICT; + } + + if (HM_UNLIKELY(putRet == FkvState::FKV_BEFORE_PUT_FUNC_FAIL)) { + return FkvState::FKV_BEFORE_PUT_FUNC_FAIL; + } + + if (HM_UNLIKELY(putRet == FkvState::FKV_NO_SPACE)) { + return FkvState::FKV_NO_SPACE; + } + + /* + * if no next bucket exist, just for break, + * else move to next bucket linked + */ + if (buck->next == nullptr) { + break; + } else { + buck = buck->next; + } + } + + /* + * if not put successfully in existing buckets, allocate a new one + * + * NOTES: just allocate memory, don't access new bucket in the spin lock scope, + * if access new bucket, which could trigger physical memory allocation which + * could trigger page fault, that is quite slow. In this case, spin lock + * could occupy too much CPU + */ + auto &lock = buck->spinLock; + lock.Lock(); + /* if other thread allocated new buck already, unlock and continue */ + if (buck->next != nullptr) { + buck = buck->next; + lock.UnLock(); + continue; + } + + /* firstly entered thread allocate new bucket */ + auto newBuck = static_cast(mOverflowEntryAlloc->Allocate(sizeof(NetHashBucket))); + if (HM_UNLIKELY(newBuck == nullptr)) { + lock.UnLock(); + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "Failed to allocate new bucket"); + return FkvState::FKV_FAIL; + } + /* link to current buck, set buck to new buck */ + buck->next = newBuck; + buck = newBuck; + + /* unlock */ + lock.UnLock(); + } + return FkvState::FKV_FAIL; + } + + FkvState Remove(uint64_t key) + { + if (HM_UNLIKELY(key == 0)) { + if (zeroInside) { + if (__sync_bool_compare_and_swap(&zeroInside, true, false)) { + zeroValue = 0; + current_size--; + } + return FkvState::FKV_EXIST; + } + return FkvState::FKV_NOT_EXIST; + } + + /* get bucket */ + auto buck = &(mSubMaps[key % gSubMapCount][key % mBucketCount]); + + /* loop all buckets linked */ + uint64_t value; + while (buck != nullptr) { + if (buck->Find(key, value)) { + buck->Remove(key); + current_size--; + return FkvState::FKV_EXIST; + } + + buck = buck->next; + } + + return FkvState::FKV_NOT_EXIST; + } + + FkvState Remove(uint64_t key, const std::function &beforeRemoveFunc) + { + if (HM_UNLIKELY(key == 0)) { + if (zeroInside) { + if (__sync_bool_compare_and_swap(&zeroInside, true, false)) { + auto ret = beforeRemoveFunc(zeroValue); + if (HM_UNLIKELY(ret == BeforeRemoveFuncState::BEFORE_FAIL)) { + return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; + } + zeroValue = 0; + current_size--; + } + return FkvState::FKV_EXIST; + } + return FkvState::FKV_NOT_EXIST; + } + + /* get bucket */ + auto buck = &(mSubMaps[key % gSubMapCount][key % mBucketCount]); + + /* loop all buckets linked */ + uint64_t value; + while (buck != nullptr) { + if (buck->Find(key, value)) { + auto ret = buck->Remove(key, beforeRemoveFunc); + if (HM_UNLIKELY(ret == FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL)) { + return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; + } + + current_size--; + return FkvState::FKV_EXIST; + } + + buck = buck->next; + } + + return FkvState::FKV_NOT_EXIST; + } + + FkvState Put(uint64_t key, uint64_t value) + { + if (HM_UNLIKELY(key == 0)) { + if (__sync_bool_compare_and_swap(&zeroInside, false, true)) { + zeroValue = value; + current_size++; + return FkvState::FKV_NOT_EXIST; + } + return FkvState::FKV_KEY_CONFLICT; + } + + /* get bucket */ + auto buck = &(mSubMaps[key % gSubMapCount][key % mBucketCount]); + /* loop all buckets linked */ + while (buck != nullptr) { + if (buck->next != nullptr) { + buck = buck->next; + } else { + break; + } + } + + // did not find, now do put. continue from the last bucket in find + /* try 8192 times */ + for (uint16_t i = 0; i < 8192; i++) { + /* loop all buckets linked */ + while (buck != nullptr) { + /* if there is an entry to put, just break */ + FkvState putRet = buck->Put(key, value, []() -> BeforePutFuncState { return {}; }); + if (putRet == FkvState::FKV_NOT_EXIST) { + current_size++; + return FkvState::FKV_NOT_EXIST; + } + + if (HM_UNLIKELY(putRet == FkvState::FKV_KEY_CONFLICT)) { + return FkvState::FKV_KEY_CONFLICT; + } + /* + * if no next bucket exist, just for break, + * else move to next bucket linked + */ + if (buck->next == nullptr) { + break; + } else { + buck = buck->next; + } + } + + /* + * if not put successfully in existing buckets, allocate a new one + * + * NOTES: just allocate memory, don't access new bucket in the spin lock scope, + * if access new bucket, which could trigger physical memory allocation which + * could trigger page fault, that is quite slow. In this case, spin lock + * could occupy too much CPU + */ + auto &lock = buck->spinLock; + lock.Lock(); + /* if other thread allocated new buck already, unlock and continue */ + if (buck->next != nullptr) { + buck = buck->next; + lock.UnLock(); + continue; + } + + /* firstly entered thread allocate new bucket */ + auto newBuck = static_cast(mOverflowEntryAlloc->Allocate(sizeof(NetHashBucket))); + if (HM_UNLIKELY(newBuck == nullptr)) { + lock.UnLock(); + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "Failed to allocate new bucket"); + return FkvState::FKV_FAIL; + } + /* link to current buck, set buck to new buck */ + buck->next = newBuck; + buck = newBuck; + + /* unlock */ + lock.UnLock(); + } + return FkvState::FKV_FAIL; + } + + bool Find(const uint64_t key, uint64_t &value) + { + if (HM_UNLIKELY(key == 0)) { + if (zeroInside) { + value = zeroValue; + return true; + } + return false; + } + /* get bucket */ + auto buck = &(mSubMaps[key % gSubMapCount][key % mBucketCount]); + + /* loop all buckets linked */ + while (buck != nullptr) { + if (buck->Find(key, value)) { + return true; + } + + buck = buck->next; + } + + return false; + } + + /* When used in muti thread, this function can only be used when keys are uniqued */ + FkvState FindAndDeleteIfFound(const uint64_t key, uint64_t &value, + const std::function &beforeRemoveFunc) + { + if (HM_UNLIKELY(key == 0)) { + if (zeroInside) { + value = zeroValue; + if (__sync_bool_compare_and_swap(&zeroInside, true, false)) { + auto ret = beforeRemoveFunc(zeroValue); + if (HM_UNLIKELY(ret == BeforeRemoveFuncState::BEFORE_FAIL)) { + return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; + } + zeroValue = 0; + current_size--; + } + + return FkvState::FKV_EXIST; + } + return FkvState::FKV_NOT_EXIST; + } + /* get bucket */ + auto buck = &(mSubMaps[key % gSubMapCount][key % mBucketCount]); + + while (buck != nullptr) { + if (buck->Find(key, value)) { + auto ret = buck->Remove(key, beforeRemoveFunc); + if (HM_UNLIKELY(ret == FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL)) { + return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; + } + current_size--; + return FkvState::FKV_EXIST; + } + + buck = buck->next; + } + + return FkvState::FKV_NOT_EXIST; + } + + std::vector> ExportVec() + { + std::vector> kvVec; + if (zeroInside) { + kvVec.emplace_back(0, zeroValue); + } + for (auto &mSubMap : mSubMaps) { + for (uint32_t j = 0; j < mBucketCount; j++) { + auto buck = &mSubMap[j]; + while (buck) { + for (int k = 0; k < 3; k++) { + if (buck->keys[k] == 0) { + continue; + } + kvVec.emplace_back(buck->keys[k].load(), buck->values[k]); + } + buck = buck->next; + } + } + } + return kvVec; + } + +protected: + static constexpr uint16_t gSubMapCount = 5; /* count of sub map */ + static constexpr uint32_t gPrimesCount = 256; + + /* make sure the size of this class is 64 bytes, fit into one cache line */ + NetHeapAllocator *mOverflowEntryAlloc = nullptr; /* allocate overflowed entry in one bucket */ + NetHashBucket *mSubMaps[gSubMapCount]{}; /* sub map */ + uint32_t mBucketCount = 0; /* bucket count of each sub map */ + uint32_t mBaseSize = 4096; /* base size */ + bool zeroInside = false; + uint64_t zeroValue = 0; + + const uint32_t gPrimes[gPrimesCount] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, + 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, + 97, 103, 109, 113, 127, 137, 139, 149, 157, 167, + 179, 193, 199, 211, 227, 241, 257, 277, 293, 313, + 337, 359, 383, 409, 439, 467, 503, 541, 577, 619, + 661, 709, 761, 823, 887, 953, 1031, 1109, 1193, 1289, + 1381, 1493, 1613, 1741, 1879, 2029, 2179, 2357, 2549, + 2753, 2971, 3209, 3469, 3739, 4027, 4349, 4703, 5087, + 5503, 5953, 6427, 6949, 7517, 8123, 8783, 9497, 10273, + 11113, 12011, 12983, 14033, 15173, 16411, 17749, 19183, + 20753, 22447, 24281, 26267, 28411, 30727, 33223, 35933, + 38873, 42043, 45481, 49201, 53201, 57557, 62233, 67307, + 72817, 78779, 85229, 92203, 99733, 107897, 116731, 126271, + 136607, 147793, 159871, 172933, 187091, 202409, 218971, 236897, + 256279, 277261, 299951, 324503, 351061, 379787, 410857, 444487, + 480881, 520241, 562841, 608903, 658753, 712697, 771049, 834181, + 902483, 976369, 1056323, 1142821, 1236397, 1337629, 1447153, + 1565659, 1693859, 1832561, 1982627, 2144977, 2320627, 2510653, + 2716249, 2938679, 3179303, 3439651, 3721303, 4026031, 4355707, + 4712381, 5098259, 5515729, 5967347, 6456007, 6984629, 7556579, + 8175383, 8844859, 9569143, 10352717, 11200489, 12117689, + 13109983, 14183539, 15345007, 16601593, 17961079, 19431899, + 21023161, 22744717, 24607243, 26622317, 28802401, 31160981, + 33712729, 36473443, 39460231, 42691603, 46187573, 49969847, + 54061849, 58488943, 63278561, 68460391, 74066549, 80131819, + 86693767, 93793069, 101473717, 109783337, 118773397, 128499677, + 139022417, 150406843, 162723577, 176048909, 190465427, + 206062531, 222936881, 241193053, 260944219, 282312799, + 305431229, 330442829, 357502601, 386778277, 418451333, + 452718089, 489790921, 529899637, 573292817, 620239453, + 671030513, 725980837, 785430967, 849749479, 919334987, + 994618837, 1076067617, 1164186217, 1259520799, 1362662261, + 1474249943, 1594975441, 1725587117, 1866894511, 2019773507, + 2185171673, 2364114217, 2557710269, 2767159799, 2993761039, + 3238918481, 3504151727, 3791104843, 4101556399, 4294967291}; + +private: + void FreeSubMaps() + { + /* free all sub maps */ + for (auto &mSubMap : mSubMaps) { + if (mSubMap != nullptr) { + delete[] mSubMap; + mSubMap = nullptr; + } + } + } + + void FreeOverFlowedEntries() + { + for (auto &mSubMap : mSubMaps) { + if (mSubMap == nullptr) { + continue; + } + + /* free overflow entries in one sub map */ + for (uint32_t buckIndex = 0; buckIndex < mBucketCount; ++buckIndex) { + auto curBuck = mSubMap[buckIndex].next; + NetHashBucket *nextOverflowEntryBuck = nullptr; + + /* exit loop when curBuck is null */ + while (curBuck != nullptr) { + /* assign next overflow buck to tmp variable */ + nextOverflowEntryBuck = curBuck->next; + + /* free this overflow bucket */ + mOverflowEntryAlloc->Free(curBuck); + + /* assign next to current */ + curBuck = nextOverflowEntryBuck; + } + } + } + } +}; +} +#endif // MXREC_MAPPER_BASE_H diff --git a/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h b/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h new file mode 100644 index 00000000..39dba5c2 --- /dev/null +++ b/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h @@ -0,0 +1,209 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef MXREC_OFFSET_MAPPER_H +#define MXREC_OFFSET_MAPPER_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "mapper_base.h" + +namespace EmbCache { +class OffsetMapper : public MapperBase { +public: + OffsetMapper() = default; + + ~OffsetMapper() = default; + + bool Initialize(uint32_t reserve, uint32_t maxSize = 0) + { + maxCacheSize = maxSize; + useLength = 0; + pos2Key.resize(maxSize); + std::fill(pos2Key.begin(), pos2Key.end(), INVALID_KEY); + try { + validPos = new LimitedSet(maxSize); + evictPos = new LimitedSet(maxSize); + } catch (const std::bad_alloc &e) { + return false; + } + return MapperBase::Initialize(reserve); + } + + void UnInitialize() override + { + delete validPos; + delete evictPos; + MapperBase::UnInitialize(); + } + + FkvState Remove(uint64_t key) + { + return MapperBase::Remove(key, [&](uint64_t value) { + validPos->remove(value); + auto pos = std::find(lastBatchPos.begin(), lastBatchPos.end(), value); + if (pos != lastBatchPos.end()) { + lastBatchPos.erase(pos); + } + evictPos->insert(value); + evictSize++; + return BeforeRemoveFuncState::BEFORE_SUCCESS; + }); + } + + std::vector> ExportSortedKVPairs() + { + auto koVec = ExportVec(); + std::sort(koVec.begin(), koVec.end(), [](const auto &u, const auto &v) { return u.second < v.second; }); + return koVec; + } + + uint64_t GetFreeLength() + { + return maxCacheSize - useLength + evictSize; + } + + int GetSwapPairsAndKey2Offset(std::vector &keys, std::vector &swapInKeys, + std::vector &swapInPos, std::vector &swapOutKeys, std::vector &swapOutPos) + { + std::vector swapInKeysID; + + for (uint64_t i = 0; i < keys.size(); i++) { + // Invalid key 不考虑 + if (HM_UNLIKELY(keys[i] == static_cast(INVALID_KEY))) { + continue; + } + // 在HBM中的key, 原地替换为pos后从validPos中移除 + // 不在HBM中的key,加入swapInKeys,并记录在keys中的下标,用于后续key->offset + if (Find(keys[i], keys[i])) { + validPos->remove(keys[i]); + } else { + swapInKeys.push_back(keys[i]); + swapInKeysID.push_back(i); + } + } + + swapInPos.resize(swapInKeys.size()); + // 换出量 = 换入量 - 剩余空间 + uint64_t swapOutNum = swapInKeys.size() <= GetFreeLength() ? 0 : swapInKeys.size() - GetFreeLength(); + swapOutKeys.resize(swapOutNum); + swapOutPos.resize(swapOutNum); + + // 空间不足,前swapOutNum个Key从evictPos中拿可换出位置 + uint64_t swapInCnt = 0; + for (uint64_t pos : *evictPos) { + if (swapInCnt == swapInKeys.size()) { + break; + } + // 记录swapInPos + swapInPos[swapInCnt] = pos; + // key->offset + keys[swapInKeysID[swapInCnt]] = pos; + // 放入新key-pos + Put(swapInKeys[swapInCnt], pos); + // 更新pos2Key + pos2Key[pos] = swapInKeys[swapInCnt]; + swapInCnt++; + evictSize--; + } + + uint64_t swapOutCnt = 0; + // 空间不足,前swapOutNum个Key从validPos中拿可换出位置 + for (uint64_t pos : *validPos) { + if (swapOutCnt == swapOutNum) { + break; + } + // 记录swapInPos + swapInPos[swapInCnt] = pos; + // key->offset + keys[swapInKeysID[swapInCnt]] = pos; + // 删除原key-pos,放入新key-pos + uint64_t key = pos2Key[pos]; + MapperBase::Remove(key); + Put(swapInKeys[swapInCnt], pos); + // 记录swapOutKoPair + swapOutKeys[swapOutCnt] = key; + swapOutPos[swapOutCnt] = pos; + // 更新pos2Key + pos2Key[pos] = swapInKeys[swapInCnt]; + swapInCnt++; + swapOutCnt++; + } + + if (swapOutCnt < swapOutNum) { + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "max cache size is too small"); + return ock::ctr::H_MAX_CACHESIZE_TOO_SMALL; + } + + // 剩下的Key从om中分配位置 + for (uint64_t i = swapInCnt; i < swapInKeys.size(); i++) { + swapInPos[i] = useLength++; + if (HM_UNLIKELY(swapInPos[i] >= maxCacheSize)) { + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "max cache size is too small"); + return ock::ctr::H_MAX_CACHESIZE_TOO_SMALL; + } + // 放入新key-pos + Put(swapInKeys[i], swapInPos[i]); + // 更新pos2Key + pos2Key[swapInPos[i]] = swapInKeys[i]; + // key->offset + keys[swapInKeysID[i]] = swapInPos[i]; + } + + // 上个batch中的pos可被换出,加入validPos中 + for (uint64_t pos : lastBatchPos) { + if (HM_UNLIKELY(pos == static_cast(INVALID_KEY))) { + continue; + } + validPos->insert(pos); + } + + // 这里keys都已被替换成offset,这个batch使用的pos在下个batch不能被换出,移出validPos + for (uint64_t pos : keys) { + if (HM_UNLIKELY(pos == static_cast(INVALID_KEY))) { + continue; + } + validPos->remove(pos); + evictPos->remove(pos); + } + + lastBatchPos = keys; + return ock::ctr::H_OK; + } + + uint32_t GetUsage() + { + return useLength - evictSize; + } + +private: + uint32_t maxCacheSize{}; // HBM可容纳embedding条数 + uint32_t useLength{}; // HBM存储的embedding条数 + LimitedSet *validPos{}; // HBM中可被换出的位置 + LimitedSet *evictPos{}; // 淘汰出的位置 + std::vector pos2Key; // HBM中每个位置对应的key + std::vector lastBatchPos; // 上个batch的keys在HBM中占用的pos + uint64_t evictSize; // evictPos的长度 +}; +} +#endif // MXREC_OFFSET_MAPPER_H diff --git a/src/AccCTR/src/factory_impl.cpp b/src/AccCTR/src/factory_impl.cpp index f0f5cdac..654e1d76 100644 --- a/src/AccCTR/src/factory_impl.cpp +++ b/src/AccCTR/src/factory_impl.cpp @@ -54,6 +54,17 @@ int FactoryImpl::CreateUnique(std::shared_ptr &out) return H_OK; } +int FactoryImpl::CreateEmbCacheManager(std::shared_ptr &out) +{ + auto tmp = new (std::nothrow) EmbCache::EmbCacheManagerImpl(); + if (tmp == nullptr) { + return H_NEW_OBJECT_FAILED; + } + + out.reset(dynamic_cast(tmp)); + return H_OK; +} + int FactoryImpl::SetExternalLogFuncInner(ExternalLog logFunc) { auto logger = ExternalLogger::Instance(); diff --git a/src/AccCTR/src/factory_impl.h b/src/AccCTR/src/factory_impl.h index cc1c025a..aa5cd211 100644 --- a/src/AccCTR/src/factory_impl.h +++ b/src/AccCTR/src/factory_impl.h @@ -17,6 +17,7 @@ limitations under the License. #include "include/factory.h" #include "unique/unique_impl.h" +#include "embedding_cache/cache_manager/cache_manager.h" namespace ock { namespace ctr { @@ -27,6 +28,7 @@ public: public: int CreateUnique(std::shared_ptr &out) override; + int CreateEmbCacheManager(std::shared_ptr &out) override; int SetExternalLogFuncInner(ExternalLog logFunc) override; public: diff --git a/src/AccCTR/src/include/CMakeLists.txt b/src/AccCTR/src/include/CMakeLists.txt index c9d2b215..7f8b2b6d 100644 --- a/src/AccCTR/src/include/CMakeLists.txt +++ b/src/AccCTR/src/include/CMakeLists.txt @@ -12,7 +12,7 @@ # limitations under the License. # ============================================================================== -set(INCLUDE_HEADERS factory.h ock_ctr_common_def.h unique.h) +set(INCLUDE_HEADERS factory.h ock_ctr_common_def.h unique.h embedding_cache.h) set(TARGET_INSTALL_INCLUDE ${OUTPUT}/ock_ctr_common/include) diff --git a/src/AccCTR/src/include/embedding_cache.h b/src/AccCTR/src/include/embedding_cache.h new file mode 100644 index 00000000..69a41136 --- /dev/null +++ b/src/AccCTR/src/include/embedding_cache.h @@ -0,0 +1,295 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef EMBEDDING_CACHE_H +#define EMBEDDING_CACHE_H + +#include +#include +#include +#include + +namespace EmbCache { +using KeyOffsetPair = std::pair, std::vector>; + +class Initializer { +public: + Initializer() = default; + virtual ~Initializer() = default; + + /* * + * 生成随机数 + * @Param emb embedding的首地址 + */ + virtual void GenerateData(float *emb) = 0; + uint32_t start{}; // 起始位置 + uint32_t len{}; // 初始化的长度 + float initParam = 1.0; // 初始化器生成的初始值均需要乘以initParam +}; + +enum class InitializerType { + INVALID, + CONSTANT, + TRUNCATED_NORMAL, + RANDOM_NORMAL +}; + +struct ConstantInitializerInfo { + ConstantInitializerInfo() = default; + + explicit ConstantInitializerInfo(float constantValue, float initK); + + float constantValue{}; // 常量值 + float initK = 1.0; // 初始化出来的值需乘以initK +}; + +struct NormalInitializerInfo { + NormalInitializerInfo() = default; + + NormalInitializerInfo(float mean, float stddev, uint32_t seed, float initK); + + float mean{}; // 平均值 + float stddev{}; // 标准差 + uint32_t seed{}; // 随机数种子 + float initK = 1.0; // 初始化出来的值需乘以initK +}; + +class ConstantInitializer : public Initializer { +public: + ConstantInitializer() = default; + + ConstantInitializer(uint32_t start, uint32_t len, float value, float initK); + + ~ConstantInitializer() override = default; + + void GenerateData(float *emb) override; + + uint32_t start{}; // 起始位置 + uint32_t len{}; // 初始化的长度 + float constantValue{}; // 常量值 +}; + +class RandomNormalInitializer : public Initializer { +public: + RandomNormalInitializer() = default; + RandomNormalInitializer(uint32_t start, uint32_t len, NormalInitializerInfo &initInfo); + + ~RandomNormalInitializer() override = default; + + void GenerateData(float * const emb) override; + + uint32_t start{}; // 起始位置 + uint32_t len{}; // 初始化的长度 + float mean{}; // 平均值 + float stddev{}; // 标准差 + uint32_t seed{}; // 随机数种子 + + std::default_random_engine generator; // 随机数生成器 + std::normal_distribution distribution; // 正态分布 +}; + +class TruncatedNormalInitializer : public Initializer { +public: + TruncatedNormalInitializer() = default; + + TruncatedNormalInitializer(uint32_t start, uint32_t len, NormalInitializerInfo &initInfo); + + ~TruncatedNormalInitializer() override = default; + + void GenerateData(float *emb) override; + + int boundNum = 2; + + uint32_t start{}; // 起始位置 + uint32_t len{}; // 初始化的长度 + float mean{}; // 平均值 + float stddev{}; // 标准差 + uint32_t seed{}; // 随机数种子 + + std::default_random_engine generator; // 随机数生成器 + std::normal_distribution distribution; + float minBound = 0; // 下界 + float maxBound = 0; // 上界 +}; + +struct InitializerInfo { + InitializerInfo() = default; + + InitializerInfo(std::string &name, uint32_t start, uint32_t len, ConstantInitializerInfo constantInitializerInfo); + + InitializerInfo(std::string &name, uint32_t start, uint32_t len, NormalInitializerInfo normalInitializerInfo); + + std::string name; // 初始化器的名称 + uint32_t start{}; // 初始化开始的位置 + uint32_t len{}; // 待初始化的长度 + InitializerType initializerType = InitializerType::INVALID; + + ConstantInitializerInfo constantInitializerInfo; + NormalInitializerInfo normalInitializerInfo; + + std::shared_ptr initializer; +}; + +struct EmbCacheInfo { + EmbCacheInfo(std::string tableName, uint32_t vocabSize, uint32_t embeddingSize, uint32_t extEmbeddingSize, + uint32_t maxCacheSize) + : tableName(tableName), + vocabSize(vocabSize), + embeddingSize(embeddingSize), + extEmbeddingSize(extEmbeddingSize), + maxCacheSize(maxCacheSize) + {} + std::string tableName; + uint32_t vocabSize; // host侧的容量(能存多少条embedding) + uint32_t embeddingSize; + uint32_t extEmbeddingSize; // 包含embedding和优化器信息的embedding长度 + uint32_t maxCacheSize; // device侧的容量(能存多少条embedding) +}; + +class EmbCacheManager { +public: + virtual ~EmbCacheManager() = default; + + /* * + * 对当前embInfo对应的table在cache_manager中进行table初始化 + * @Param EmbCacheInfo: embedding cache的初始化信息 + * @Param std::vector 初始化器的信息 + * @Param uint64_t prefillBufferSize emb内存池恒定可用大小 + * @Param uint32_t refillThreadNum emb内存池自动填充线程数 + * @Return errorCode + */ + virtual int CreateCacheForTable(const EmbCacheInfo &embCacheInfo, + const std::vector &initializerInfos, int64_t invalidKey = -1, + uint64_t prefillBufferSize = 500000, uint32_t refillThreadNum = 1) = 0; + + /* * + * 查找当前keys对应的offsets并将本不存在与offsetMapper中的keys插入到offsetMapper中并得到其偏移值offsets, + * 并且当offsetMapper可存放空间不足时,释放可swapOut的keys,获取当前需要被换入换出的keys和offsets的pair + * @Param tableName: 表名 + * @Param keys: 当前batch所有unique的keys + * @Param swapInKoPair: 输出参数,需要换入的Key-offset pair + * @Param swapOutKoPair: 输出参数,需要换出的Key-offset pair + * @Return errorCode + */ + virtual int GetSwapPairsAndKey2Offset(std::string tableName, std::vector &keys, + KeyOffsetPair &swapInKoPair, KeyOffsetPair &swapOutKoPair) = 0; + + /* * + * 查询Embedding + * @Param tableName: 表名 + * @Param keys: 待查询的keys + * @Param embAddr: 申请出来存放embedding的空间首地址 + * @Param threadNum: 线程数 + * @Return errorCode + */ + virtual int EmbeddingLookup(std::string tableName, const std::vector &keys, float *embAddr, + uint32_t threadNum = 4) = 0; + + /* * + * 查询Embedding的地址 + * @Param tableName: 表名 + * @Param keys: 待查询的keys + * @Param addrs: keys对应的申请出来存放embedding的空间首地址 + * @Param threadNum: 线程数 + * @Return errorCode + */ + virtual int EmbeddingLookupAddrs(std::string tableName, const std::vector &keys, + std::vector &addrs, uint32_t threadNum = 4) = 0; + + /* * + * 查询Embedding并且在查询完成之后删除embedding对应的key。如果多线程使用,严格保证传入的key线程间不会重复(unique + * key),否则可能出现未定义结果 + * @Param tableName: 表名 + * @Param keys: 待查询的keys + * @Param embAddr: 申请出来存放embedding的空间首地址 + * @Param threadNum: 线程数 + * @Return errorCode + */ + virtual int EmbeddingLookupAndRemove(std::string tableName, const std::vector &keys, float *embAddr, + uint32_t threadNum = 4) = 0; + + /* * + * 更新Embedding + * @Param tableName: 表名 + * @Param keys: 待更新的keys,用于查询出每个key在DDR上存放的地址 + * @Param embAddr: 待更新到DDR上的embedding的首地址 + * @Param threadNum: 线程数 + * @Return errorCode + */ + virtual int EmbeddingUpdate(std::string tableName, const std::vector &keys, float *embAddr, + uint32_t threadNum = 4) = 0; + + /* * + * 在EmbLocalTable中移除keys,并将存储其embedding的内存位置记为可复用 + * @Param tableName: 表名 + * @Param keys: 待移除的keys + * @Return errorCode + */ + virtual int EmbeddingRemove(std::string tableName, const std::vector &keys, uint32_t threadNum = 4) = 0; + + /* * + * 将需要被淘汰的keys从offsetMapper的记录中移除,同时也在EmbLocalTable中移除,并将存储其embedding的内存位置记为可复用 + * @Param tableName: 表名 + * @Param keys: 待淘汰的keys + * @Return errorCode + */ + virtual int RemoveEmbsByKeys(std::string tableName, const std::vector &keys) = 0; + + /* * + * 获取所有table names + * @Param allTableNames: 输出参数,用于存放所有的table names + * @Return errorCode + */ + virtual int GetEmbTableNames(std::vector &allTableNames) = 0; + + /* * + * 获取以values为增序排列的当前记录在offsetMapper中所有的keys和values的pairs + * @Param tableName: 表名 + * koVec: 输出参数 + * @Return errorCode + */ + virtual int ExportDeviceKeyOffsetPairs(std::string tableName, + std::vector> &koVec) = 0; + + /* * + * 获取当前table的序列化信息 + * @Param tableName: 要序列化的表 + * @Param buffer: 输出参数,存储序列化之后的信息 + * @Return errorCode + */ + virtual int Serialize(std::string tableName, std::vector &buffer) = 0; + + /* * + * 将当前table的序列化信息进行反序列化 + * @Param tableName: 要反序列化的表 + * @Param buffer: 输入参数,将buffer中的内容进行反序列化 + * @Return errorCode + */ + virtual int Deserialize(std::string tableName, const std::vector &buffer) = 0; + + /* * + * 析构所有embCache,释放内存 + */ + virtual void Destroy() = 0; + + /* * + * 查询表的使用量 + * @Param tableName: 要查询的表 + * @Return 当前表的使用量 + */ + virtual uint32_t GetUsage(const std::string &tableName) = 0; +}; +} + +#endif // EMBEDDING_CACHE_H diff --git a/src/AccCTR/src/include/factory.h b/src/AccCTR/src/include/factory.h index 14732cf9..69e8217a 100644 --- a/src/AccCTR/src/include/factory.h +++ b/src/AccCTR/src/include/factory.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include "unique.h" +#include "embedding_cache.h" #ifdef __cplusplus @@ -39,11 +40,13 @@ class Factory; using FactoryPtr = std::shared_ptr; using UniquePtr = std::shared_ptr; +using EmbCacheManagerPtr = std::shared_ptr; class Factory { public: virtual ~Factory() = default; virtual int CreateUnique(UniquePtr &out) = 0; + virtual int CreateEmbCacheManager(EmbCacheManagerPtr &out) = 0; virtual int SetExternalLogFuncInner(ExternalLog logFunc) = 0; public: @@ -52,7 +55,7 @@ public: int result = 0; uintptr_t factory = 0; /* dynamic load function */ - if ((result = OckCtrCommonDef::CreatFactory(&factory)) == 0) { + if ((result = OckCtrCommonDef::CreateFactory(&factory)) == 0) { out.reset(reinterpret_cast(factory)); } return result; diff --git a/src/AccCTR/src/include/ock_ctr_common_def.h b/src/AccCTR/src/include/ock_ctr_common_def.h index ed955996..75e7e9cb 100644 --- a/src/AccCTR/src/include/ock_ctr_common_def.h +++ b/src/AccCTR/src/include/ock_ctr_common_def.h @@ -25,7 +25,7 @@ namespace ock { namespace ctr { class OckCtrCommonDef { public: - static int CreatFactory(uintptr_t *factory) + static int CreateFactory(uintptr_t *factory) { static void *handle = nullptr; static std::mutex m; diff --git a/src/AccCTR/src/include/unique.h b/src/AccCTR/src/include/unique.h index 3154a784..1f58f8a4 100644 --- a/src/AccCTR/src/include/unique.h +++ b/src/AccCTR/src/include/unique.h @@ -58,6 +58,7 @@ using UniqueConf = struct UniqueConfCTR { uint32_t maxThreadNum = 8; // 最大工作线程数 int64_t maxIdVal = 0; // 最大id值 bool trace = false; // 是否开启性能检测,需要配合外部日志输出 + bool performance = false; // 是否开启增强接口,增强接口shardingNum必须是2的幂次方,默认用取模分桶 } __attribute__((packed)); using UniqueIn = struct UniqueInCTR { diff --git a/src/AccCTR/src/unique/unique_func.cpp b/src/AccCTR/src/unique/unique_func.cpp index d208eac9..462d6f9e 100644 --- a/src/AccCTR/src/unique/unique_func.cpp +++ b/src/AccCTR/src/unique/unique_func.cpp @@ -27,7 +27,6 @@ void Dedup::Insert(uint64_t val) for (int8_t i = 0; i < count; ++i) { if (bucket->data[totalCount] == val) { - TryIncreaseIdCount(bucket->idCount[totalCount]); // found one return; } @@ -38,7 +37,6 @@ void Dedup::Insert(uint64_t val) std::lock_guard lg(bucket->lock); for (int8_t j = totalCount; j < bucket->count; ++j) { if (bucket->data[totalCount] == val) { - TryIncreaseIdCount(bucket->idCount[totalCount]); // found one return; } @@ -47,7 +45,6 @@ void Dedup::Insert(uint64_t val) if (totalCount < n) { bucket->data[totalCount] = val; bucket->count++; - TryIncreaseIdCount(bucket->idCount[totalCount]); return; } } @@ -55,13 +52,6 @@ void Dedup::Insert(uint64_t val) InsertOverflow(val); } -inline void Dedup::TryIncreaseIdCount(std::atomic &val) -{ - if (idCountEnable_) { - val++; - } -} - int32_t Dedup::GetReplaceOffsetUnsafe(uint64_t val) { auto h = static_cast(Hash(val) & bucketCountMask_); @@ -108,7 +98,6 @@ void Dedup::Clear(uint64_t newBucketCountPowerOf2) } bzero(table_, sizeof(Meta) * bucketCount_); overflow_.clear(); - idCountOverflow_.clear(); } void Dedup::NewParameter() @@ -168,6 +157,51 @@ int32_t ShardedDedup::GetFillOffset(const std::vector &totalUniqueSize, } } +void ShardedDedup::GetIndexAndStart(const int32_t *uniqueSizeInBucket, bool usePadding, int shardingNumber, int &start, + int &index) +{ + if (shardingNumber > 0) { + index += uniqueSizeInBucket[shardingNumber - 1]; + } + + if (usePadding) { + start = shardingNumber * conf.paddingSize; + } else { + start = index; + } +} + +int ShardedDedup::PrintMemCpyLog(int rc, const uint32_t dstSize, const std::string &logMsg) +{ + if (rc != 0) { + std::stringstream ssm; + ssm << "[" << logMsg << "] memcpy_s failed... dstSize: " << dstSize; + ExternalLogger::PrintLog(LogLevel::ERROR, ssm.str()); + return H_COPY_ERROR; + } else { + return H_OK; + } +} + +int ShardedDedup::HandleIdCountFill(std::vector> &idCount, UniqueOutSelf &uniqueOut) +{ + if (conf.usePadding) { + uint32_t memSize = idCount.size() * sizeof(int32_t); + auto rc = memcpy_s(uniqueOut.idCntFill, memSize, (int32_t *)(idCount.data()), memSize); + int ret = PrintMemCpyLog(rc, memSize, "[TileAndFill/idCntFill]"); + if (ret != 0) { + return ret; + } + } else { + uint32_t memSize = idCount.size() * sizeof(int32_t); + auto rc = memcpy_s(uniqueOut.idCnt, memSize, (int32_t *)(idCount.data()), memSize); + int ret = PrintMemCpyLog(rc, memSize, "[TileAndFill/idCnt]"); + if (ret != 0) { + return ret; + } + } + return H_OK; +} size_t ShardedDedup::CalThreadNum() const { diff --git a/src/AccCTR/src/unique/unique_func.h b/src/AccCTR/src/unique/unique_func.h index 07c8ebb7..4812f74c 100644 --- a/src/AccCTR/src/unique/unique_func.h +++ b/src/AccCTR/src/unique/unique_func.h @@ -30,6 +30,7 @@ limitations under the License. #include #include #include +#include #include "securec.h" #include "common_includes.h" @@ -37,6 +38,14 @@ limitations under the License. namespace ock { namespace ctr { +#ifndef LIKELY +#define LIKELY(x) __builtin_expect(!!(x), 1) +#endif + +#ifndef UNLIKELY +#define UNLIKELY(x) __builtin_expect(!!(x), 0) +#endif + using UniqueOutSelf = struct UniqueSelf { void *uniqueId = nullptr; // 去重分桶填充之后最终的的ids(需要用户申请)必选 uint32_t *index = nullptr; // 去重后id的索引位置(需要用户申请)必选 @@ -47,7 +56,7 @@ using UniqueOutSelf = struct UniqueSelf { int uniqueIdCnt = 0; // 每个桶去重后的id个数(需要用户申请) }; -constexpr int UNIQUE_MAX_BUCKET_WIDTH = 5; +constexpr int UNIQUE_MAX_BUCKET_WIDTH = 6; template struct Map {}; template <> struct Map { @@ -111,7 +120,7 @@ class Dedup { static constexpr uint32_t K_MINIMAL_WORKLOAD_PER_WORKER = 1 << 12; static constexpr size_t K_ALIGNMENT = 64; static const int kDefaultBucketCount = 1 << 24; - static const int8_t n = 4; + static const int8_t n = UNIQUE_MAX_BUCKET_WIDTH; template struct Meta { static_assert(M <= UNIQUE_MAX_BUCKET_WIDTH, "should be no larger than max bucket width"); @@ -119,7 +128,6 @@ class Dedup { volatile int8_t count {}; uint32_t replaceBase {}; volatile uint64_t data[M] {}; - std::atomic idCount[M] {}; } __attribute__((__aligned__(64))); struct Statistics { @@ -152,11 +160,10 @@ public: void Insert(uint64_t val); int32_t GetReplaceOffsetUnsafe(uint64_t val); void InitTable(); - void TryIncreaseIdCount(std::atomic &val); void Clear(uint64_t newBucketCountPowerOf2); void NewParameter(); - template uint32_t UniqueRaw(void *output, uint32_t priorTotal, int32_t *idCount) + template uint32_t UniqueRaw(void *output, uint32_t priorTotal) { uint32_t total = priorTotal; uint32_t replaceOffset = priorTotal; @@ -168,19 +175,13 @@ public: } bucket->replaceBase = replaceOffset; for (int j = 0; j < bucket->count; ++j) { - if (idCountEnable_) { - idCount[total] = bucket->idCount[j]; - } - out[total++] = static_cast::type>(bucket->data[j]); + out[total++] = bucket->data[j]; } replaceOffset += bucket->count; } auto it = overflow_.begin(); int32_t totalOverflow = 0; while (it != overflow_.end()) { - if (idCountEnable_) { - idCount[total] = static_cast(idCountOverflow_[it->first]); - } out[total++] = it->first; it->second = replaceOffset++; ++it; @@ -189,7 +190,7 @@ public: // set total overflow count stats_.totalUniques = static_cast(total - priorTotal); - stats_.totalOverflowUniques = static_cast(totalOverflow); + stats_.totalOverflowUniques = totalOverflow; return total - priorTotal; } @@ -200,14 +201,13 @@ private: int largeCount_ { 0 }; Meta *table_ {}; std::unordered_map overflow_; - std::unordered_map idCountOverflow_; SpinLockG overflowMutex_; Statistics stats_; bool idCountEnable_ { false }; static inline uint64_t Hash(uint64_t val) { - return val ^ (val >> HASH_L_L) ^ (val >> HASH_L_L) ^ (val >> HASH_H); + return val ^ (val >> HASH_L_L) ^ (val >> HASH_L) ^ (val >> HASH_H); } void InsertOverflow(uint64_t val) @@ -217,10 +217,6 @@ private: if (it == overflow_.end()) { overflow_[val] = 0; } - - if (idCountEnable_) { - idCountOverflow_[val]++; - } } int32_t GetReplaceOffsetFromOverflowUnsafe(uint64_t val) @@ -234,6 +230,7 @@ class ShardedDedup { static constexpr uint32_t K_MINIMAL_WORKLOAD_PER_WORKER = 1 << 13; static constexpr int K_DEFAULT_DUPLICATE_RATIO = 4; static constexpr int K_BUCKET_WIDTH = 4; + static constexpr int CLEAR_WAIT_TIME = 10; public: using DedupT = Dedup; @@ -244,44 +241,45 @@ public: { const int numOfGroupsInShard = groupMethod_.GroupCount(); uint32_t totalSize = conf.desiredSize + (conf.desiredSize >> 1); - while (bucketCountPower2_ * static_cast(K_BUCKET_WIDTH) * - static_cast(numOfGroupsInShard) * static_cast(estimatedDuplicateRatio) < totalSize) { + while (bucketCountPower2_ * K_BUCKET_WIDTH * numOfGroupsInShard * estimatedDuplicateRatio < totalSize) { bucketCountPower2_ <<= 1; } idCountEnable_ = (conf.outputType == OutputType::ENHANCED) && conf.useIdCount; - try { - for (int32_t i = 0; i < numOfGroupsInShard; ++i) { - auto obj = new DedupT(bucketCountPower2_, numOfGroupsInShard, idCountEnable_); - dedupShards_.emplace_back(obj); + for (int32_t i = 0; i < numOfGroupsInShard; ++i) { + auto obj = new DedupT(bucketCountPower2_, numOfGroupsInShard, idCountEnable_); + if (obj == nullptr) { + ExternalLogger::PrintLog(LogLevel::ERROR, "creat object error"); + throw NullptrError(); } - } catch (const std::bad_alloc& e) { - ExternalLogger::PrintLog(LogLevel::ERROR, "Memory allocation failed during loop: " + std::string(e.what())); - throw; + dedupShards_.emplace_back(obj); } } ~ShardedDedup() = default; - void StartNewRound() + int StartNewRound() { for (auto &s : dedupShards_) { s->NewParameter(); } + clearFinish_ = true; + return 0; } public: template int Compute(UniqueIn &uniqueIn, UniqueOutSelf &uniqueOut) { - try { - if (!firstEnterFlag_) { - StartNewRound(); - } - } catch (AllocError &) { - ExternalLogger::PrintLog(LogLevel::ERROR, "memory alloc error"); - return H_MEMORY_ALLOC_ERROR; + if (firstEnter_) { + pool_.SetNumThreads(1); + firstEnter_ = false; } - firstEnterFlag_ = false; + + while (!clearFinish_) { + usleep(CLEAR_WAIT_TIME); + } + + clearFinish_ = false; size_t threadNum = CalThreadNum(); partSize = (uniqueIn.inputIdCnt + threadNum - 1) / threadNum; @@ -304,23 +302,29 @@ public: if (conf.outputType == OutputType::ENHANCED) { int totalNumber = 0; for (int i = 0; i < conf.shardingNum; i++) { - totalUniqueSize[i] = static_cast(totalNumber); + totalUniqueSize[i] = totalNumber; if (conf.useSharding) { totalNumber += uniqueOut.uniqueIdCntInBucket[i]; } } } - ret = CalUniqueOut(uniqueIn, uniqueOut, totalUniqueSize); + int size = 1; + if (conf.useIdCount) { + size = conf.usePadding ? conf.paddingSize * conf.shardingNum : uniqueOut.uniqueIdCnt; + } + std::vector> idCount(size); + ret = CalUniqueOut(uniqueIn, uniqueOut, totalUniqueSize, idCount); if (ret != H_OK) { ExternalLogger::PrintLog(LogLevel::ERROR, "CalUniqueOut ERROR"); return ret; } if (conf.outputType == OutputType::ENHANCED) { - HandleTileAndFill(uniqueIn, uniqueOut); + HandleTileAndFill(uniqueOut, idCount); } + pool_.AddTask([this]() { return StartNewRound(); }); return H_OK; } @@ -336,17 +340,22 @@ private: int32_t GetFillOffset(const std::vector &totalUniqueSize, int64_t val, int32_t group); - template int HandleTileAndFill(UniqueIn &uniqueIn, UniqueOutSelf &uniqueOut) + void GetIndexAndStart(const int32_t *uniqueSizeInBucket, bool usePadding, int shardingNumber, int &start, + int &index); + + int PrintMemCpyLog(int rc, const uint32_t dstSize, const std::string &logMsg); + + int HandleIdCountFill(std::vector> &idCount, UniqueOutSelf &uniqueOut); + + template int HandleTileAndFill(UniqueOutSelf &uniqueOut, std::vector> &idCount) { int ret = H_OK; if (conf.useSharding) { // 使能shard - ret = TileAndFill(uniqueOut.uniqueIdInBucket, uniqueOut.uniqueIdCntInBucket, uniqueOut.uniqueId, - uniqueOut.idCnt, uniqueOut.idCntFill); + ret = TileAndFill(uniqueOut, uniqueOut.uniqueIdCntInBucket, idCount); } else if (!conf.useSharding && conf.useIdCount) { // 不使能shard和使能特征计数 std::vector count; count.emplace_back(uniqueOut.uniqueIdCnt); // 记录去重后id个数 - ret = TileAndFill(uniqueOut.uniqueId, count.data(), uniqueOut.uniqueId, uniqueOut.idCnt, - uniqueOut.idCntFill); + ret = TileAndFill(uniqueOut, count.data(), idCount); } if (ret != H_OK) { @@ -365,37 +374,37 @@ private: uint64_t inGroupTotal; if (conf.outputType == OutputType::ENHANCED) { if (conf.useSharding && conf.useIdCount) { - inGroupTotal = dedupShards_[j]->UniqueRaw(uniqueOut.uniqueIdInBucket, total, - uniqueOut.idCnt); // 特征计数使能和shard同时使能 - uniqueOut.uniqueIdCntInBucket[j] = static_cast(inGroupTotal); + inGroupTotal = + dedupShards_[j]->UniqueRaw(uniqueOut.uniqueIdInBucket, total); // 特征计数使能和shard同时使能 + uniqueOut.uniqueIdCntInBucket[j] = inGroupTotal; } else if (!conf.useSharding && conf.useIdCount) { - inGroupTotal = dedupShards_[j]->UniqueRaw(uniqueOut.uniqueId, total, - uniqueOut.idCnt); // 特征计数使能和shard不使能 + inGroupTotal = + dedupShards_[j]->UniqueRaw(uniqueOut.uniqueId, total); // 特征计数使能和shard不使能 } else if (conf.useSharding && !conf.useIdCount) { - inGroupTotal = dedupShards_[j]->UniqueRaw(uniqueOut.uniqueIdInBucket, total, - nullptr); // 特征计数使能和shard不使能 - uniqueOut.uniqueIdCntInBucket[j] = static_cast(inGroupTotal); + inGroupTotal = + dedupShards_[j]->UniqueRaw(uniqueOut.uniqueIdInBucket, total); // 特征计数使能和shard不使能 + uniqueOut.uniqueIdCntInBucket[j] = inGroupTotal; } else { - inGroupTotal = dedupShards_[j]->UniqueRaw(uniqueOut.uniqueId, total, - nullptr); // 特征计数不使能和shard不使能,跟普通unique对等 + inGroupTotal = dedupShards_[j]->UniqueRaw(uniqueOut.uniqueId, + total); // 特征计数不使能和shard不使能,跟普通unique对等 } } else { - inGroupTotal = dedupShards_[j]->UniqueRaw(uniqueOut.uniqueId, total, nullptr); + inGroupTotal = dedupShards_[j]->UniqueRaw(uniqueOut.uniqueId, total); } - total += static_cast(inGroupTotal); + total += inGroupTotal; } uniqueOut.uniqueIdCnt = total; } template - int TileAndFill(void *uniqueIdInBucket, const int32_t *uniqueSizeInBucket, void *uniqueIds, const int32_t *idCnt, - int32_t *idCntFill) + int TileAndFill(UniqueOutSelf &uniqueOut, const int32_t *uniqueSizeInBucket, + std::vector> &idCount) { int start = 0; int index = 0; - auto uIdInBucket = TypeTrans(uniqueIdInBucket); - auto uIds = TypeTrans(uniqueIds); + auto uIdInBucket = TypeTrans(conf.useSharding ? uniqueOut.uniqueIdInBucket : uniqueOut.uniqueId); + auto uIds = TypeTrans(uniqueOut.uniqueId); for (int i = 0; i < conf.shardingNum; i++) { GetIndexAndStart(uniqueSizeInBucket, conf.usePadding, i, start, index); @@ -419,35 +428,31 @@ private: if (conf.useIdCount && conf.usePadding) { memSize = uniqueSizeInBucket[i] * sizeof(int32_t); - rc = memcpy_s(idCntFill + start, memSize, idCnt + index, memSize); - ret = PrintMemCpyLog(rc, memSize, "[TileAndFill/idCntFill]"); + rc = memcpy_s(uniqueOut.idCnt + index, memSize, (int32_t *)(idCount.data()) + start, + memSize); // 填充idCount + ret = PrintMemCpyLog(rc, memSize, "[TileAndFill/idCnt]"); + } + + if (ret != 0) { + return ret; } + } + + if (conf.useIdCount) { + int ret = HandleIdCountFill(idCount, uniqueOut); if (ret != 0) { return ret; } } if (conf.usePadding) { - HandleFill(uIds, uniqueSizeInBucket, idCntFill); + HandleFill(uIds, uniqueSizeInBucket); } return H_OK; } - int PrintMemCpyLog(int rc, const uint32_t dstSize, const std::string &logMsg) - { - if (rc != 0) { - std::stringstream ssm; - ssm << "[" << logMsg << "] memcpy_s failed... dstSize: " << dstSize; - ExternalLogger::PrintLog(LogLevel::ERROR, ssm.str()); - return H_COPY_ERROR; - } else { - return H_OK; - } - } - - template - void HandleFill(typename Map::type *uIds, const int32_t *uniqueSizeInBucket, int32_t *idCntFill) + template void HandleFill(typename Map::type *uIds, const int32_t *uniqueSizeInBucket) { int start = 0; int index = 0; @@ -459,26 +464,6 @@ private: for (int j = 0; j < fillLen; j++) { uIds[start + uniqueSizeInBucket[i] + j] = conf.paddingVal; // padding填充 } - - if (idCntFill != nullptr) { - for (int y = 0; y < fillLen; y++) { - idCntFill[start + uniqueSizeInBucket[i] + y] = 0; // 特征计数填充 - } - } - } - } - - void GetIndexAndStart(const int32_t *uniqueSizeInBucket, bool usePadding, int shardingNumber, int &start, - int &index) - { - if (shardingNumber > 0) { - index += uniqueSizeInBucket[shardingNumber - 1]; - } - - if (usePadding) { - start = shardingNumber * conf.paddingSize; - } else { - start = index; } } @@ -493,13 +478,18 @@ private: tasks.push_back([this, val, start, end, &ret]() { for (uint64_t j = start; j < end; ++j) { auto value = val[j]; - if (value > conf.maxIdVal) { + if (UNLIKELY(value > conf.maxIdVal)) { ExternalLogger::PrintLog(LogLevel::ERROR, "id val is larger than maxIdVal"); ret = H_ID_LARGE; break; } - auto group = groupMethod_.GroupId(value); - dedupShards_[group]->Insert(value); + + if (conf.performance) { + dedupShards_[value & (conf.shardingNum - 1)]->Insert(value); + } else { + auto group = groupMethod_.GroupId(value); + dedupShards_[group]->Insert(value); + } } }); } @@ -520,31 +510,46 @@ private: } template - int CalUniqueOut(UniqueIn &uniqueIn, UniqueOutSelf &uniqueOut, std::vector &totalUniqueSize) + int CalUniqueOut(UniqueIn &uniqueIn, UniqueOutSelf &uniqueOut, std::vector &totalUniqueSize, + std::vector> &idCount) { uint32_t *beginPtr = uniqueOut.index; uint32_t *finishPtr = beginPtr + uniqueIn.inputIdCnt; uint32_t *partBeginPtr = beginPtr; - auto alignedAddress = CacheLineAlign(reinterpret_cast(partBeginPtr + partSize)); - auto *partEndPtr = reinterpret_cast(static_cast(alignedAddress)); + auto *partEndPtr = + reinterpret_cast(CacheLineAlign(reinterpret_cast(partBeginPtr + partSize))); std::vector> tasks; auto val = TypeTrans(uniqueIn.inputId); while (partBeginPtr < finishPtr) { if (partEndPtr > finishPtr) { partEndPtr = finishPtr; } - if (partBeginPtr < partEndPtr) { - // Due to cacheline alignment computation, the actual number of - // threads created here may not match threadNum exactly but - // should be +/-1 off. - tasks.push_back([this, val, beginPtr, partBeginPtr, partEndPtr, totalUniqueSize]() { - for (uint32_t *ptr = partBeginPtr; ptr < partEndPtr; ++ptr) { + + if (partBeginPtr >= partEndPtr) { + partBeginPtr = partEndPtr; + partEndPtr += partSize; + continue; + } + + // Due to cacheline alignment computation, the actual number of + // threads created here may not match threadNum exactly but + // should be +/-1 off. + tasks.push_back([this, val, beginPtr, partBeginPtr, partEndPtr, totalUniqueSize, &idCount]() { + for (uint32_t *ptr = partBeginPtr; ptr < partEndPtr; ++ptr) { + int32_t fillOffset; + if (conf.performance) { + fillOffset = GetFillOffset(totalUniqueSize, val[ptr - beginPtr], + val[ptr - beginPtr] & (conf.shardingNum - 1)); + } else { auto group = groupMethod_.GroupId(val[ptr - beginPtr]); - int32_t fillOffset = GetFillOffset(totalUniqueSize, val[ptr - beginPtr], group); - *ptr = fillOffset; + fillOffset = GetFillOffset(totalUniqueSize, val[ptr - beginPtr], group); } - }); - } + *ptr = fillOffset; + if (LIKELY(conf.useIdCount)) { + idCount[fillOffset]++; + } + } + }); partBeginPtr = partEndPtr; partEndPtr += partSize; } @@ -569,8 +574,10 @@ private: UniqueConf conf; std::vector> dedupShards_ {}; uint32_t partSize; - bool firstEnterFlag_ = false; + bool clearFinish_ = true; bool idCountEnable_ { false }; + ThreadPoolAsync pool_; + bool firstEnter_ = true; }; } } diff --git a/src/AccCTR/src/unique/unique_impl.cpp b/src/AccCTR/src/unique/unique_impl.cpp index 77113214..800f21de 100644 --- a/src/AccCTR/src/unique/unique_impl.cpp +++ b/src/AccCTR/src/unique/unique_impl.cpp @@ -228,6 +228,14 @@ int UniqueImpl::CheckEnhancedUniqueConf(const UniqueConf &conf) if (CheckInputZero(conf.shardingNum, "shardingNum")) { return H_NUM_SMALL; } + if (conf.performance) { + bool isExponentOfTwo = + (conf.shardingNum > 0) && ((conf.shardingNum & (conf.shardingNum - 1)) == 0); // 判断是不是2的N次幂 + if (!isExponentOfTwo) { + ExternalLogger::PrintLog(LogLevel::ERROR, "if performance is true, shardingNum must be 2^N"); + return H_ERROR; + } + } } return H_OK; diff --git a/src/AccCTR/src/unique/unique_impl.h b/src/AccCTR/src/unique/unique_impl.h index f4c45fde..e37a58db 100644 --- a/src/AccCTR/src/unique/unique_impl.h +++ b/src/AccCTR/src/unique/unique_impl.h @@ -43,7 +43,7 @@ private: private: ShardedDedup *unique = nullptr; - UniqueConf uniqueConf {}; + UniqueConf uniqueConf{}; }; } } diff --git a/src/AccCTR/tests/tools/create_fake_id.py b/src/AccCTR/tests/tools/create_fake_id.py index fc0f1f8e..aa42f071 100644 --- a/src/AccCTR/tests/tools/create_fake_id.py +++ b/src/AccCTR/tests/tools/create_fake_id.py @@ -68,12 +68,6 @@ def write_data(file_name, x, y, dup): def main(): - # 300w id去重率20% - # 6x + y =300 - # x + y = 60 - # x = 48 y =12 - write_data('data20.txt', 48*10000, 12*10000, 6) - # 300w id去重率30% # 6x + y =300 # x + y = 90 diff --git a/src/AccCTR/tests/ut/conf/toolchain.cmake b/src/AccCTR/tests/ut/conf/toolchain.cmake new file mode 100644 index 00000000..bd6617e4 --- /dev/null +++ b/src/AccCTR/tests/ut/conf/toolchain.cmake @@ -0,0 +1,24 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# 添加编译选项 +option(USE32BIT "Use 32-Bit" OFF) +if(USE32BIT) + add_compile_options(-m32) + add_link_options(-m32) +endif() + +add_compile_options(-Wall) +set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_STANDARD 11) \ No newline at end of file diff --git a/src/AccCTR/tests/ut/src/CMakeLists.txt b/src/AccCTR/tests/ut/src/CMakeLists.txt index a4c631e8..3da58244 100644 --- a/src/AccCTR/tests/ut/src/CMakeLists.txt +++ b/src/AccCTR/tests/ut/src/CMakeLists.txt @@ -19,6 +19,11 @@ set(OCK_CTR_UTIL_INSTALL_DIR ${PROJECT_SOURCE_DIR}/install) set(OCK_CTR_SRC_DIR ${PROJECT_SOURCE_DIR}/src) message("src" ${OCK_CTR_SRC_DIR}) +# 包含所有组件的cmake +include("${CMAKE_CURRENT_SOURCE_DIR}/../conf/toolchain.cmake") +set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../src) +set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../) + file(GLOB_RECURSE TEST_UNIQUE_FILES *.cpp *.h) add_executable(test_unique_files ${TEST_UNIQUE_FILES}) include_directories(${OCK_CTR_UTIL_INSTALL_DIR}/googletest-release-1.8.1/include) @@ -29,17 +34,36 @@ SET(LIB_3RD_GTEST ${OCK_CTR_UTIL_INSTALL_DIR}/googletest-release-1.8.1/lib64/lib message(${OCK_CTR_SRC_DIR}/include) +# 添加库文件的搜索路径 +target_link_directories(test_unique_files + PUBLIC + ${PROJECT_SOURCE_DIR}/output/ock_ctr_common/lib + ) +# 添加头文件的搜索路径 target_include_directories(test_unique_files PUBLIC - ${OCK_CTR_SRC_DIR}/include) + ${OCK_CTR_SRC_DIR}/include + ${PROJECT_SOURCE_DIR} + ${OCK_CTR_SRC_DIR}/common/util + ) +# 用来指定要链接的库 target_link_libraries(test_unique_files PUBLIC -Wl,--start-group + _ock_ctr_common pthread dl ${LIB_3RD_GTEST} ${LIB_3RD_GMOCK} -Wl,--end-group) +# 打印构建选项 +get_target_property(COMPILE_FLAGS test_unique_files COMPILE_OPTIONS) +get_target_property(LINK_FLAGS test_unique_files LINK_OPTIONS) +message(STATUS "Compiler id: ${CMAKE_CXX_COMPILER_ID}") +message(STATUS "Compile flags: ${COMPILE_FLAGS}") +message(STATUS "Link flags: ${LINK_FLAGS}") +message(STATUS "Build Type: ${CMAKE_BUILD_TYPE}") + diff --git a/src/AccCTR/tests/ut/src/common.h b/src/AccCTR/tests/ut/src/common.h new file mode 100644 index 00000000..7302d10c --- /dev/null +++ b/src/AccCTR/tests/ut/src/common.h @@ -0,0 +1,64 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef CTR_COMMON_H +#define CTR_COMMON_H +#include + +#include "factory.h" + +extern ock::ctr::FactoryPtr factory; + +enum CTRLogLevel { + DEBUG = 0, + INFO, + WARN, + ERROR, +}; + +class SimpleThreadPool { +public: + static void SyncRun(const std::vector> &tasks) + { + std::vector> futs; + for (auto &task : tasks) { + futs.push_back(std::async(task)); + } + for (auto &fut : futs) { + fut.wait(); + } + } +}; + +static void CTRLog(int level, const char *msg) +{ + switch (level) { + case CTRLogLevel::DEBUG: + std::cout << "DEBUG:" << msg << std::endl; + break; + case CTRLogLevel::INFO: + std::cout << "INFO:" << msg << std::endl; + break; + case CTRLogLevel::WARN: + std::cout << "WARN:" << msg << std::endl; + break; + case CTRLogLevel::ERROR: + std::cout << "ERROR:" << msg << std::endl; + break; + default: + break; + } +} + +#endif // CTR_COMMON_H diff --git a/src/AccCTR/tests/ut/src/emb_cache_test.cpp b/src/AccCTR/tests/ut/src/emb_cache_test.cpp new file mode 100644 index 00000000..582f6e68 --- /dev/null +++ b/src/AccCTR/tests/ut/src/emb_cache_test.cpp @@ -0,0 +1,1653 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#include +#include + +#include "common/util/error_code.h" +#include "emb_cache_test.h" +#include "common.h" + +FactoryPtr factory; +EmbCacheManagerPtr embCache = nullptr; + +std::vector GenKeys(uint64_t n, uint32_t seed = 0, uint64_t min = 0, uint64_t max = UINT64_MAX) +{ + std::mt19937 generator(seed); + std::uniform_int_distribution distribution(min, max); + std::vector data(n); + for (uint64_t &x : data) { + x = distribution(generator); + } + sort(data.begin(), data.end()); + data.erase(unique(data.begin(), data.end()), data.end()); + return data; +} + +std::vector GenUniqueKeys(uint64_t n) +{ + std::vector data(n); + for (uint64_t i = 0; i < n; i++) { + data[i] = i; + } + return data; +} + +EmbCacheManagerPtr EmbCacheTest::SimpleCreateTable(std::string tableName, uint32_t hostVocabSize, + uint32_t embeddingSize, uint32_t extEmbeddingSize, uint32_t devVocabSize, pair normalPara, + float constPara) +{ + factory->CreateEmbCacheManager(embCache); + EmbCache::EmbCacheInfo embCacheInfo(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + + EmbCache::NormalInitializerInfo normalInitializerInfo(normalPara.first, normalPara.second, 0, 1.0); + std::string normalInitializeName = "random_normal_initializer"; + EmbCache::InitializerInfo normalInitializeInfo(normalInitializeName, 0, embeddingSize, normalInitializerInfo); + + EmbCache::ConstantInitializerInfo constantInitializerInfo(constPara, 1.0); + std::string constantInitializeName = "constant_initializer"; + + std::vector initializeInfos(extEmbeddingSize / embeddingSize); + initializeInfos[0] = normalInitializeInfo; + for (uint64_t i = 1; i < initializeInfos.size(); i++) { + initializeInfos[i] = EmbCache::InitializerInfo(constantInitializeName, embeddingSize * i, embeddingSize, + constantInitializerInfo); + } + int ret = embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize, 1); + if (ret != H_OK) { + string msg = "CreateCacheForTable Failed. ret: " + std::to_string(ret); + CTRLog(CTRLogLevel::ERROR, msg.c_str()); + return nullptr; + } + return embCache; +} + +EmbCacheManagerPtr EmbCacheTest::ConstZeroCreateTable(std::string tableName, uint32_t hostVocabSize, + uint32_t embeddingSize, uint32_t extEmbeddingSize, uint32_t devVocabSize, uint64_t prefillBufferSize, + uint8_t prefillThreadNum) +{ + factory->CreateEmbCacheManager(embCache); + EmbCache::EmbCacheInfo embCacheInfo(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + + EmbCache::ConstantInitializerInfo constantInitializerInfo(0.0, 1.0); + std::string constantInitializeName = "constant_initializer"; + + std::vector initializeInfos = { EmbCache::InitializerInfo(constantInitializeName, 0, + extEmbeddingSize, constantInitializerInfo) }; + int ret = embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, prefillBufferSize, prefillThreadNum); + if (ret != H_OK) { + string msg = "CreateCacheForTable Failed. ret: " + std::to_string(ret); + CTRLog(CTRLogLevel::ERROR, msg.c_str()); + return nullptr; + } + return embCache; +} + +void EmbCacheTest::SetUpTestCase() +{ + Factory::Create(factory); + factory->SetExternalLogFuncInner(CTRLog); +} + +void EmbCacheTest::TearDownTestCase() {} + +void EmbCacheTest::SetUp() {} + +void EmbCacheTest::TearDown() +{ + if (embCache != nullptr) { + embCache->Destroy(); + embCache = nullptr; + } +} + +TEST_F(EmbCacheTest, ConstantInitializerInfo) +{ + CTRLog(CTRLogLevel::INFO, "===========ConstantInitializerInfo start============="); + + // 正确初始化ConstantInitializerInfo结构体,无日志信息反馈 + EmbCache::ConstantInitializerInfo constantInitializerInfo(0.233, 1.0); + CTRLog(CTRLogLevel::INFO, "===========ConstantInitializerInfo end============="); +} + +TEST_F(EmbCacheTest, NormalInitializerInfo) +{ + CTRLog(CTRLogLevel::INFO, "===========NormalInitializerInfo start============="); + // 正确初始化NormalInitializerInfo结构体,无日志信息反馈 + EmbCache::NormalInitializerInfo normalInitializerInfo(0, 0.05, 0, 1.0); + // 标准差负值数学意义不明,传入负值问题用户自己承担 + EmbCache::NormalInitializerInfo normalInitializerInfo_ne_dev(0, -0.05, 0, 1.0); + CTRLog(CTRLogLevel::INFO, "===========NormalInitializerInfo end============="); +} + +TEST_F(EmbCacheTest, InitializerInfo) +{ + CTRLog(CTRLogLevel::INFO, "===========InitializerInfo start============="); + uint32_t embeddingSize = 13; + + EmbCache::NormalInitializerInfo normalInitializerInfo(0, 0.05, 0, 1.0); + EmbCache::ConstantInitializerInfo constantInitializerInfo(0.233, 1.0); + + // 传入的std::string不为"constant_initializer" 日志打印"Invalid Initializer Type." + std::string not_a_initializer_name = "not_a_initializer_name"; + EmbCache::InitializerInfo constantInitializeInfo = + EmbCache::InitializerInfo(not_a_initializer_name, embeddingSize, embeddingSize, constantInitializerInfo); + + // 传入的std::string不为"constant_initializer" 日志打印"Invalid Initializer Type." + not_a_initializer_name = ""; + constantInitializeInfo = + EmbCache::InitializerInfo(not_a_initializer_name, embeddingSize, embeddingSize, constantInitializerInfo); + + // 正确初始化InitializeInfo结构体,无日志信息反馈 + std::string constantInitializeName = "constant_initializer"; + constantInitializeInfo = + EmbCache::InitializerInfo(constantInitializeName, embeddingSize, embeddingSize + 1, constantInitializerInfo); + + // 传入的std::string不为"random_normal_initializer"或truncated_normal_initializer 日志打印"Invalid Initializer + // Type." + not_a_initializer_name = "not_a_initializer_name"; + EmbCache::InitializerInfo normalInitializeInfo = + EmbCache::InitializerInfo(not_a_initializer_name, embeddingSize, embeddingSize, normalInitializerInfo); + + // 传入的std::string不为"random_normal_initializer"或truncated_normal_initializer 日志打印"Invalid Initializer + // Type." + not_a_initializer_name = ""; + normalInitializeInfo = + EmbCache::InitializerInfo(not_a_initializer_name, embeddingSize, embeddingSize, normalInitializerInfo); + + // 正确初始化InitializeInfo结构体,无日志信息反馈 + std::string normalInitializeName = "random_normal_initializer"; + normalInitializeInfo = EmbCache::InitializerInfo(normalInitializeName, 0, embeddingSize, normalInitializerInfo); + + // 正确初始化InitializeInfo结构体,无日志信息反馈 + std::string truncatedNormalInitializeName = "truncated_normal_initializer"; + EmbCache::InitializerInfo truncatedNormalInitializeInfo = + EmbCache::InitializerInfo(truncatedNormalInitializeName, 0, embeddingSize, normalInitializerInfo); + + CTRLog(CTRLogLevel::INFO, "===========InitializerInfo end============="); +} + +TEST_F(EmbCacheTest, EmbCacheInfo) +{ + CTRLog(CTRLogLevel::INFO, "===========EmbCacheInfo start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 5; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 2; + // 正确初始化EmbCacheInfo结构体,无日志信息反馈 + EmbCache::EmbCacheInfo embCacheInfo(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + CTRLog(CTRLogLevel::INFO, "===========EmbCacheInfo end============="); +} + +TEST_F(EmbCacheTest, CreateCacheForTable) +{ + factory->CreateEmbCacheManager(embCache); + CTRLog(CTRLogLevel::INFO, "===========CreateCacheForTable start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 5; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 2; + EmbCache::EmbCacheInfo embCacheInfo(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, {}, -1, hostVocabSize), H_INITIALIZER_INVALID); + + EmbCache::NormalInitializerInfo normalInitializerInfo(0, 0.05, 0, 1.0); + std::string normalInitializeName = "random_normal_initializer"; + EmbCache::InitializerInfo normalInitializeInfo(normalInitializeName, 0, embeddingSize, normalInitializerInfo); + + // 空initializer 日志打印出"Initializer is nullptr" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, { {}, {} }, -1, hostVocabSize), H_INITIALIZER_INVALID); + + normalInitializeInfo.initializer = nullptr; + // 空initializer 日志打印出"Initializer is nullptr" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, { normalInitializeInfo }, -1, hostVocabSize), + H_INITIALIZER_INVALID); + + normalInitializeInfo = EmbCache::InitializerInfo(normalInitializeName, 0, embeddingSize, normalInitializerInfo); + EmbCache::ConstantInitializerInfo constantInitializerInfo(0.233, 1.0); + std::string constantInitializeName = "constant_initializer"; + EmbCache::InitializerInfo constantInitializeInfo(constantInitializeName, embeddingSize, embeddingSize + 1, + constantInitializerInfo); + std::vector initializeInfos = { normalInitializeInfo, constantInitializeInfo }; + + // initializerInfos的区间之间有重叠或者遗漏 日志打印出"Initializers got coverage problems" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_INITIALIZER_INVALID); + + constantInitializeInfo = + EmbCache::InitializerInfo(constantInitializeName, embeddingSize + 1, embeddingSize, constantInitializerInfo); + initializeInfos = { normalInitializeInfo, constantInitializeInfo }; + // initializerInfos的区间之间有重叠或者遗漏 日志打印出"Initializers got coverage problems" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_INITIALIZER_INVALID); + + + embCacheInfo.extEmbeddingSize = extEmbeddingSize; + std::string not_a_initializer_name = "not_a_initializer_name"; + constantInitializeInfo = + EmbCache::InitializerInfo(not_a_initializer_name, embeddingSize, embeddingSize, constantInitializerInfo); + initializeInfos = { normalInitializeInfo, constantInitializeInfo }; + + // 传入的Initializer的name不符要求 日志打印出"Invalid Initializer Type.\nInitializer is nullptr" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_INITIALIZER_INVALID); + + constantInitializeInfo = + EmbCache::InitializerInfo(constantInitializeName, embeddingSize, embeddingSize, constantInitializerInfo); + initializeInfos = { normalInitializeInfo, constantInitializeInfo }; + + embCacheInfo.extEmbeddingSize++; + + // 传入的embInfo中的传入的extEmbeddingSize并非embeddingSize的整数倍 日志打印出"extEmbeddingSize = embeddingSize + + // optimizerSize, which is divisible by embeddingSize" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), + H_EXT_EMBEDDING_SIZE_INVALID); + + embCacheInfo.maxCacheSize = 100; + // maxCacheSize>vocabSize 日志打印出"vocabSize must be greater than or equal to maxCacheSize" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), + H_HOST_VOCAB_SIZE_TOO_SMALL); + embCacheInfo.maxCacheSize = devVocabSize; + + embCacheInfo.extEmbeddingSize = 0; + // extEmbeddingSize为0 日志打印出"size must be positive" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_SIZE_ZERO); + embCacheInfo.extEmbeddingSize = extEmbeddingSize; + + embCacheInfo.embeddingSize = 0; + // embeddingSize为0 日志打印出"size must be positive" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_SIZE_ZERO); + embCacheInfo.embeddingSize = embeddingSize; + + embCacheInfo.vocabSize = 0; + // vocabSize为0 日志打印出"size must be positive" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_SIZE_ZERO); + embCacheInfo.vocabSize = hostVocabSize; + + embCacheInfo.maxCacheSize = 0; + // maxCacheSize为0 日志打印出"size must be positive" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_SIZE_ZERO); + embCacheInfo.maxCacheSize = devVocabSize; + + embCacheInfo.tableName = ""; + // 传入的tableName空 日志打印出"tableName can not be empty" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_TABLE_NAME_EMPTY); + + embCacheInfo.tableName = + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "0000000001000000000100000000010001"; + // 传入的tableName长度正好为长度上限1024 + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_OK); + + embCacheInfo.tableName = + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100012"; + // 传入的tableName长度为1025超过了长度上限 + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_TABLE_NAME_TOO_LONG); + embCacheInfo.tableName = tableName; + + // 正常创建 日志中不会打印异常信息 + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_OK); + + // 重复创建同名Table 日志打印出"This table has already been created" + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), + H_TABLE_CREATE_DUPLICATE); + embCache->Destroy(); + + // Destroy后仍能正常创建 日志中不会打印异常信息 + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize), H_OK); + embCache->Destroy(); + + // prefill单线程 + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, 3, 1), H_OK); + embCache->Destroy(); + + // prefill多线程 + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, 3, 3), H_OK); + embCache->Destroy(); + + // prefill多线程 + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, 3, 0), H_THREAD_NUM_ERROR); + embCache->Destroy(); + + // prefill过多线程 + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, 3, 10000), H_THREAD_NUM_ERROR); + embCache->Destroy(); + + // prefill 正常buffersize + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, 3, 1), H_OK); + embCache->Destroy(); + + // prefill 超大buffersize + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, 10, 1), H_PREFILL_BUFFER_SIZE_INVALID); + embCache->Destroy(); + + // prefill 0buffersize + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, 0, 1), H_PREFILL_BUFFER_SIZE_INVALID); + CTRLog(CTRLogLevel::INFO, "===========CreateCacheForTable end============="); +} + +TEST_F(EmbCacheTest, EMBEDDING_LOOKUP_ADDRS) +{ + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP_ADDRS start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 5; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 2; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::vector lookupKeys; + std::vector addrs; + + lookupKeys = { 0, 1, 2, 3, 4 }; + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tableName, lookupKeys, addrs), H_OK); + + // lookupkeys 为空 + lookupKeys = {}; + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tableName, lookupKeys, addrs), H_OK); + + lookupKeys = { 0 }; + ASSERT_EQ(embCache->EmbeddingLookupAddrs("not_a_table", lookupKeys, addrs), H_TABLE_NOT_EXIST); + + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tooLongTableName, lookupKeys, addrs), H_TABLE_NAME_TOO_LONG); + + lookupKeys = { 5 }; + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tableName, lookupKeys, addrs), H_HOST_VOCAB_SIZE_TOO_SMALL); + + lookupKeys = { 5 }; + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tableName, lookupKeys, addrs, 1), H_HOST_VOCAB_SIZE_TOO_SMALL); + + lookupKeys = { 0, 1, 4 }; + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tableName, lookupKeys, addrs), H_OK); + + lookupKeys = { 0, 1, 4 }; + uint32_t threadNum = std::thread::hardware_concurrency(); + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tableName, lookupKeys, addrs, threadNum + 1), H_THREAD_NUM_ERROR); + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tableName, lookupKeys, addrs, threadNum), H_OK); + // 单线程lookup + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tableName, lookupKeys, addrs, 1), H_OK); + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tableName, lookupKeys, addrs, 0), H_THREAD_NUM_ERROR); + embCache->Destroy(); +} + +TEST_F(EmbCacheTest, EMBEDDING_LOOKUP_ADDRS_DATA) +{ + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP_ADDRS_DATA start============="); + factory->CreateEmbCacheManager(embCache); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 3000000; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 39; + uint32_t devVocabSize = 100000; + EmbCache::EmbCacheInfo embCacheInfo(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::string normalInitializeName = "random_normal_initializer"; + std::string constantInitializeName = "constant_initializer"; + EmbCache::NormalInitializerInfo normalInitializerInfo(0, 0.05, 0, 1.0); + EmbCache::ConstantInitializerInfo constantInitializerInfo(0.233, 1.0); + + std::string truncatedNormalInitializeName = "truncated_normal_initializer"; + // 加入所有初始化器的所有分支 + std::vector initializeInfos = { + EmbCache::InitializerInfo(normalInitializeName, 0, embeddingSize, normalInitializerInfo), + EmbCache::InitializerInfo(normalInitializeName, embeddingSize, 0, normalInitializerInfo), + EmbCache::InitializerInfo(constantInitializeName, embeddingSize, embeddingSize, constantInitializerInfo), + EmbCache::InitializerInfo(constantInitializeName, 2 * embeddingSize, 0, constantInitializerInfo), + EmbCache::InitializerInfo(truncatedNormalInitializeName, 2 * embeddingSize, embeddingSize, + normalInitializerInfo), + EmbCache::InitializerInfo(truncatedNormalInitializeName, 3 * embeddingSize, 0, normalInitializerInfo), + }; + // 正确创建 + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos), H_OK); + std::vector lookupKeys; + std::vector addrs; + lookupKeys = GenKeys(hostVocabSize, 123321); + ASSERT_EQ(embCache->EmbeddingLookupAddrs(tableName, lookupKeys, addrs), H_OK); + + long double sum = 0.0; + long double cnt = 0.0; + long double accum = 0.0; + for (uint32_t i = 0; i < lookupKeys.size(); i++) { + // normalInitializer 生成数据 + for (uint32_t j = 0; j < embeddingSize; j++) { + sum += addrs[i][j]; + cnt++; + } + + // constantInitializer 生成数据 + for (uint32_t j = embeddingSize; j < 2 * embeddingSize; j++) { + ASSERT_LE(std::abs(addrs[i][j] - 0.233), 1e-6f); + } + // truncatedNormalInitializer 生成数据 + for (uint32_t j = 2 * embeddingSize; j < 3 * embeddingSize; j++) { + // 在[-2*stddev, 2*stddev]范围中 + ASSERT_LE(std::abs(addrs[i][j]), 0.1f + 1e-6f); + } + } + + long double mean = sum / cnt; + for (uint32_t i = 0; i < lookupKeys.size(); ++i) { + for (uint32_t j = 0; j < embeddingSize; j++) { + accum += (addrs[i][j] - mean) * (addrs[i][j] - mean); + } + } + long double stdev = sqrt(accum / cnt); + ASSERT_LE(std::abs(mean), 5e-6f); + ASSERT_LE(std::abs(stdev - 0.05), 5e-6f); + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP_ADDRS_DATA end============="); +} + +TEST_F(EmbCacheTest, EMBEDDING_LOOKUP_300W) +{ + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP_300W start============="); + factory->CreateEmbCacheManager(embCache); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 3000000; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 39; + uint32_t devVocabSize = 100000; + EmbCache::EmbCacheInfo embCacheInfo(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::string normalInitializeName = "random_normal_initializer"; + std::string constantInitializeName = "constant_initializer"; + EmbCache::NormalInitializerInfo normalInitializerInfo(0, 0.05, 0, 1.0); + EmbCache::ConstantInitializerInfo constantInitializerInfo(0.233, 1.0); + + std::string truncatedNormalInitializeName = "truncated_normal_initializer"; + // 加入所有初始化器的所有分支 + std::vector initializeInfos = { + EmbCache::InitializerInfo(normalInitializeName, 0, embeddingSize, normalInitializerInfo), + EmbCache::InitializerInfo(normalInitializeName, embeddingSize, 0, normalInitializerInfo), + EmbCache::InitializerInfo(constantInitializeName, embeddingSize, embeddingSize, constantInitializerInfo), + EmbCache::InitializerInfo(constantInitializeName, 2 * embeddingSize, 0, constantInitializerInfo), + EmbCache::InitializerInfo(truncatedNormalInitializeName, 2 * embeddingSize, embeddingSize, + normalInitializerInfo), + EmbCache::InitializerInfo(truncatedNormalInitializeName, 3 * embeddingSize, 0, normalInitializerInfo), + }; + // 正确创建 + ASSERT_EQ(embCache->CreateCacheForTable(embCacheInfo, initializeInfos), H_OK); + std::vector lookupKeys; + float *addr; + lookupKeys = GenKeys(hostVocabSize, 123321); + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + + long double sum = 0.0; + long double cnt = 0.0; + long double accum = 0.0; + for (uint32_t i = 0; i < lookupKeys.size(); i++) { + // normalInitializer 生成数据 + for (uint32_t j = 0; j < embeddingSize; j++) { + sum += addr[i * extEmbeddingSize + j]; + cnt++; + } + + // constantInitializer 生成数据 + for (uint32_t j = embeddingSize; j < 2 * embeddingSize; j++) { + ASSERT_LE(std::abs(addr[i * extEmbeddingSize + j] - 0.233), 1e-6f); + } + // truncatedNormalInitializer 生成数据 + for (uint32_t j = 2 * embeddingSize; j < 3 * embeddingSize; j++) { + // 在[-2*stddev, 2*stddev]范围中 + ASSERT_LE(std::abs(addr[i * extEmbeddingSize + j]), 0.1f + 1e-6f); + } + } + + long double mean = sum / cnt; + for (uint32_t i = 0; i < lookupKeys.size(); ++i) { + for (uint32_t j = 0; j < embeddingSize; j++) { + accum += (addr[i * extEmbeddingSize + j] - mean) * (addr[i * extEmbeddingSize + j] - mean); + } + } + long double stdev = sqrt(accum / cnt); + ASSERT_LE(std::abs(mean), 5e-6f); + ASSERT_LE(std::abs(stdev - 0.05), 5e-6f); + free(addr); + CTRLog(CTRLogLevel::INFO, "===========GenerateData end============="); +} + +TEST_F(EmbCacheTest, EMBEDDING_LOOKUP_AND_REMOVE) +{ + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP_AND_REMOVE start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 5; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 2; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::vector lookupKeys; + float *addr; + + lookupKeys = { 0, 1, 2, 3, 4 }; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tableName, lookupKeys, addr), H_OK); + free(addr); + + // lookupkeys 为空 + lookupKeys = {}; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tableName, lookupKeys, addr), H_OK); + free(addr); + + lookupKeys = { 0 }; + addr = nullptr; + ASSERT_EQ(embCache->EmbeddingLookupAndRemove("not_a_table", lookupKeys, addr), H_TABLE_NOT_EXIST); + + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tooLongTableName, lookupKeys, addr), H_TABLE_NAME_TOO_LONG); + + lookupKeys = { 0 }; + addr = nullptr; + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tableName, lookupKeys, addr), H_ADDRESS_NULL); + + lookupKeys = { 0, 1, 4 }; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + uint32_t threadNum = std::thread::hardware_concurrency(); + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tableName, lookupKeys, addr, threadNum + 1), H_THREAD_NUM_ERROR); + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tableName, lookupKeys, addr, threadNum), H_OK); + // 单线程lookup + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tableName, lookupKeys, addr, 1), H_OK); + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tableName, lookupKeys, addr, 0), H_THREAD_NUM_ERROR); + free(addr); + embCache->Destroy(); + + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP_AND_REMOVE end============="); +} + +TEST_F(EmbCacheTest, EMBEDDING_LOOKUP_AND_REMOVE_2) +{ + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP_AND_REMOVE_2 start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 200; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 2; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::vector lookupKeys; + float *addr; + + for (int i = 0; i < 100; i++) { + for (int j = 0; j < 2; j++) { + lookupKeys.emplace_back(i); + } + } + + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tableName, lookupKeys, addr, 1), H_OK); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tableName, lookupKeys, addr), H_OK); + free(addr); + embCache->Destroy(); + + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP_AND_REMOVE_2 end============="); +} + +TEST_F(EmbCacheTest, EMBEDDING_LOOKUP) +{ + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 5; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 2; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::vector lookupKeys; + float *addr; + + lookupKeys = { 0, 1, 2, 3, 4 }; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + free(addr); + + // lookupkeys 为空 + lookupKeys = {}; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + free(addr); + + lookupKeys = { 0 }; + addr = nullptr; + ASSERT_EQ(embCache->EmbeddingLookup("not_a_table", lookupKeys, addr), H_TABLE_NOT_EXIST); + + ASSERT_EQ(embCache->EmbeddingLookup(tooLongTableName, lookupKeys, addr), H_TABLE_NAME_TOO_LONG); + + lookupKeys = { 5 }; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_HOST_VOCAB_SIZE_TOO_SMALL); + free(addr); + + lookupKeys = { 0 }; + addr = nullptr; + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_ADDRESS_NULL); + + lookupKeys = { 0, 1, 4 }; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + free(addr); + + lookupKeys = { 0, 1, 4 }; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + uint32_t threadNum = std::thread::hardware_concurrency(); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr, threadNum + 1), H_THREAD_NUM_ERROR); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr, threadNum), H_OK); + // 单线程lookup + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr, 1), H_OK); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr, 0), H_THREAD_NUM_ERROR); + free(addr); + embCache->Destroy(); + + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP end============="); +} + +TEST_F(EmbCacheTest, EMBEDDING_LOOKUP_AND_REMOVE_300W) +{ + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP_AND_REMOVE_300W start============="); + std::string tableName = "test_table"; + std::vector lookupKeys; + float *newEmb; + + // 300w个key + uint32_t hostVocabSize = 3000000; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 100000; + embCache = ConstZeroCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + lookupKeys = GenUniqueKeys(hostVocabSize); + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size(); i++) { + for (uint32_t j = 0; j < extEmbeddingSize; j++) { + newEmb[i * extEmbeddingSize + j] = i + 0.01f * j; // 生成特殊数据 + } + } + CTRLog(CTRLogLevel::INFO, "gen done"); + // 把特殊数据放到表中 + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_OK); + free(newEmb); + CTRLog(CTRLogLevel::INFO, "EmbeddingUpdate done"); + + float *addr; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + // 查询特殊数据 + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + CTRLog(CTRLogLevel::INFO, "EmbeddingLookup done"); + for (uint32_t i = 0; i < lookupKeys.size(); i++) { + for (uint32_t j = 0; j < extEmbeddingSize; j++) { + // 验证表中数据正确性 + ASSERT_LE(std::abs(addr[i * extEmbeddingSize + j] - (i + 0.01f * j)), 1e-6f); + } + } + free(addr); + addr = nullptr; + + // Remove之后再Lookup,观察这些embedding是不是被正确remove + // 首先确认EmbeddingLookupAndRemove不会报错 + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookupAndRemove(tableName, lookupKeys, addr, 4), H_OK); + for (uint32_t i = 0; i < lookupKeys.size(); i++) { + for (uint32_t j = 0; j < extEmbeddingSize; j++) { + // 验证表中数据正确性 + ASSERT_LE(std::abs(addr[i * extEmbeddingSize + j] - (i + 0.01f * j)), 1e-6f); + } + } + free(addr); + addr = nullptr; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + // 然后再lookup,并确保lookup不会报错 + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + // 因为用const zero初始化, EmbeddingLookupAndRemove之后再lookup,结果应该全是0 + for (uint32_t i = 0; i < lookupKeys.size(); i++) { + for (uint32_t j = 0; j < extEmbeddingSize; j++) { + // 验证表中数据正确性 + ASSERT_LE(std::abs(addr[i * extEmbeddingSize + j] - 0), 1e-6f); + } + } + free(addr); + + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_LOOKUP_AND_REMOVE_300W end============="); +} + +TEST_F(EmbCacheTest, EMBEDDING_UPDATE_300W) +{ + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_UPDATE_300W start============="); + std::string tableName = "test_table"; + std::vector lookupKeys; + float *newEmb; + + // 300w个key + uint32_t hostVocabSize = 3000000; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 100000; + embCache = ConstZeroCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize, 50000, 6); + lookupKeys = GenKeys(hostVocabSize, 123321); + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size(); i++) { + for (uint32_t j = 0; j < extEmbeddingSize; j++) { + newEmb[i * extEmbeddingSize + j] = i + 0.01f * j; // 生成特殊数据 + } + } + CTRLog(CTRLogLevel::INFO, "gen done"); + // 把特殊数据放到表中 + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_OK); + free(newEmb); + CTRLog(CTRLogLevel::INFO, "EmbeddingUpdate done"); + + float *addr; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + // 查询特殊数据 + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + CTRLog(CTRLogLevel::INFO, "EmbeddingLookup done"); + for (uint32_t i = 0; i < lookupKeys.size(); i++) { + for (uint32_t j = 0; j < extEmbeddingSize; j++) { + // 验证表中数据正确性 + ASSERT_LE(std::abs(addr[i * extEmbeddingSize + j] - (i + 0.01f * j)), 1e-6f); + } + } + // Remove之后再Lookup,观察这些embedding是不是被正确remove + // 首先确认remove不会报错 + ASSERT_EQ(embCache->RemoveEmbsByKeys(tableName, lookupKeys), H_OK); + // 然后再lookup,并确保lookup不会报错 + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + // 因为用const zero初始化, 删除之后再lookup,结果应该全是0 + for (uint32_t i = 0; i < lookupKeys.size(); i++) { + for (uint32_t j = 0; j < extEmbeddingSize; j++) { + // 验证表中数据正确性 + ASSERT_LE(std::abs(addr[i * extEmbeddingSize + j] - 0), 1e-6f); + } + } + free(addr); + + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_UPDATE_300W end============="); +} + +TEST_F(EmbCacheTest, EMBEDDING_UPDATE) +{ + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_UPDATE start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 5; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 2; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::vector lookupKeys; + float *newEmb; + + lookupKeys = { 0, 1, 2, 3, 4 }; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + + // 更新存在的table,应当正常更新 + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_OK); + free(newEmb); + + lookupKeys = { 0 }; + newEmb = nullptr; + // 更新不存在的table + ASSERT_EQ(embCache->EmbeddingUpdate("not_a_table", lookupKeys, newEmb), H_TABLE_NOT_EXIST); + + // 表名超过上限 + ASSERT_EQ(embCache->EmbeddingUpdate(tooLongTableName, lookupKeys, newEmb), H_TABLE_NAME_TOO_LONG); + + lookupKeys = { 5 }; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + + // 当前embLocalTable中存储的key已达到hostVocabSize上限,并继续添加新key + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_HOST_VOCAB_SIZE_TOO_SMALL); + free(newEmb); + + lookupKeys = { 0 }; + newEmb = nullptr; + // 传入embAddr为空指针 + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_ADDRESS_NULL); + + // 更新存在于table的keys, 传入embAddr不为空指针 + lookupKeys = { 0, 1, 4 }; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_OK); + free(newEmb); + + // 线程数未超过核数 + lookupKeys = { 0, 1, 4 }; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb, 4), H_OK); + free(newEmb); + + // 线程数等于核数 + uint32_t processCoreNum = std::thread::hardware_concurrency(); + lookupKeys = { 0, 1, 4 }; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb, processCoreNum), H_OK); + free(newEmb); + + // 线程数大于核数 + processCoreNum = std::thread::hardware_concurrency(); + lookupKeys = { 0, 1, 4 }; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb, processCoreNum + 1), H_THREAD_NUM_ERROR); + free(newEmb); + + // 线程数为0 + processCoreNum = std::thread::hardware_concurrency(); + lookupKeys = { 0, 1, 4 }; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb, 0), H_THREAD_NUM_ERROR); + free(newEmb); + + // 线程数为1 + lookupKeys = { 0, 1, 4 }; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb, 1), H_OK); + free(newEmb); + + // lookupkeys为空 + lookupKeys = {}; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb, 1), H_OK); + free(newEmb); + + TearDown(); + + // 更新不存在于table的key,且当前embLocalTable中存储的key未达到hostVocabSize上限,继续添加新key + tableName = "test_table_one"; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + lookupKeys = { 0, 1 }; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb); + free(newEmb); + lookupKeys = { 2, 3 }; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_OK); + free(newEmb); + + + CTRLog(CTRLogLevel::INFO, "===========EMBEDDING_UPDATE end============="); +} + +TEST_F(EmbCacheTest, GetSwapPairsAndKey2Offset) +{ + CTRLog(CTRLogLevel::INFO, "===========GetSwapPairsAndKey2Offset start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 100; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 10; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::vector insertKeys; + std::pair, std::vector> swapInKoPair, swapOutKoPair; + + // 使用不存在的table + insertKeys = { 0, 1, 2, 3, 4 }; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset("not_a_table", insertKeys, swapInKoPair, swapOutKoPair), + H_TABLE_NOT_EXIST); + + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tooLongTableName, insertKeys, swapInKoPair, swapOutKoPair), + H_TABLE_NAME_TOO_LONG); + + // 正常查找不存在的keys + insertKeys = { 0, 1, 2, 3, 4 }; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair, swapOutKoPair), H_OK); + bool ret1 = true; + for (uint64_t i = 0; i < swapInKoPair.first.size(); i++) { + if (swapInKoPair.first[i] != i) { + string msg = "the " + std::to_string(i) + "th has key " + std::to_string(swapInKoPair.first[i]) + + ", but expect " + std::to_string(i); + CTRLog(CTRLogLevel::INFO, msg.c_str()); + ret1 = false; + } + } + ASSERT_EQ(ret1, true); + + // 正常查找存在的keys + std::pair, std::vector> swapInKoPair2, swapOutKoPair2; + insertKeys = { 1, 2, 3 }; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair2, swapOutKoPair2), H_OK); + uint64_t uint_zero = 0; + ASSERT_EQ(swapInKoPair2.first.size(), uint_zero); + + std::pair, std::vector> swapInKoPair3, swapOutKoPair3; + insertKeys = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + // 使用非空的koPair + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair, swapOutKoPair3), + H_ARG_NOT_EMPTY); + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair3, swapInKoPair), H_ARG_NOT_EMPTY); + // 存入keys正好达到maxCacheSize上限值 + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair3, swapOutKoPair3), H_OK); + + // 存入keys正好越过到maxCacheSize上限值 + std::pair, std::vector> swapInKoPair4, swapOutKoPair4; + insertKeys = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair4, swapOutKoPair4), + H_MAX_CACHESIZE_TOO_SMALL); + + embCache->Destroy(); + // 单次存入keys超过maxCacheSize上限值 + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::pair, std::vector> swapInKoPair5, swapOutKoPair5; + insertKeys = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 }; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair5, swapOutKoPair5), + H_MAX_CACHESIZE_TOO_SMALL); + + embCache->Destroy(); + // 单次存入keys正好达到上限值后,再次查找已存在的keys + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::pair, std::vector> swapInKoPair6, swapOutKoPair6; + insertKeys = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair6, swapOutKoPair6), H_OK); + + embCache->Destroy(); + // 连续两次存入的keys未超过上限,第三次传入keys达到上限 + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::pair, std::vector> swapInKoPair7, swapOutKoPair7; + insertKeys = { 0, 1, 2, 3, 4 }; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair7, swapOutKoPair7), H_OK); + + std::pair, std::vector> swapInKoPair8, swapOutKoPair8; + insertKeys = { 5, 6, 7, 8, 9 }; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair8, swapOutKoPair8), H_OK); + + std::pair, std::vector> swapInKoPair9, swapOutKoPair9; + insertKeys = { 10, 11, 12, 13, 14 }; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair9, swapOutKoPair9), H_OK); + + embCache->Destroy(); + // 查询INVALID_KEY + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::pair, std::vector> swapInKoPair10, swapOutKoPair10; + uint64_t neg_one = -1; + insertKeys = { neg_one, neg_one, neg_one, neg_one, neg_one }; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair10, swapOutKoPair10), H_OK); + ASSERT_EQ(swapInKoPair10.first.empty(), true); + ASSERT_EQ(swapInKoPair10.second.empty(), true); + ASSERT_EQ(swapOutKoPair10.first.empty(), true); + ASSERT_EQ(swapOutKoPair10.second.empty(), true); + + // 查找空keys + std::pair, std::vector> swapInKoPair11, swapOutKoPair11; + insertKeys = {}; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, insertKeys, swapInKoPair11, swapOutKoPair11), H_OK); + ASSERT_EQ(swapInKoPair11.first.empty(), true); + ASSERT_EQ(swapInKoPair11.second.empty(), true); + ASSERT_EQ(swapOutKoPair11.first.empty(), true); + ASSERT_EQ(swapOutKoPair11.second.empty(), true); + CTRLog(CTRLogLevel::INFO, "===========GetSwapPairsAndKey2Offset end============="); +} + + +bool checkKeys(std::set &keySet, std::vector> &historyKeyVec, + const std::vector &keys, const std::vector &swapInKeys, + const std::vector &swapOutKeys, uint32_t maxCacheSize) +{ + std::set newKeys; + for (auto key : keys) { + if (keySet.find(key) == keySet.end()) { + newKeys.insert(key); + } + keySet.insert(key); + } + for (auto key : swapInKeys) { + if (newKeys.find(key) == newKeys.end()) { + CTRLog(CTRLogLevel::ERROR, "swapIn key error1"); + return false; + } + } + if (swapInKeys.size() != newKeys.size()) { + CTRLog(CTRLogLevel::ERROR, "swapIn key error2"); + return false; + } + historyKeyVec.insert(historyKeyVec.begin(), { keys.begin(), keys.end() }); + if (historyKeyVec.size() > 2) { + historyKeyVec.pop_back(); + } + for (auto key : swapOutKeys) { + if (historyKeyVec[0].find(key) != historyKeyVec[0].end() || + historyKeyVec[1].find(key) != historyKeyVec[1].end()) { + CTRLog(CTRLogLevel::ERROR, "swapOut key error1"); + return false; + } + } + for (auto key : swapOutKeys) { + if (keySet.find(key) == keySet.end()) { + CTRLog(CTRLogLevel::ERROR, "swapOut key error2"); + return false; + } + } + for (auto key : swapOutKeys) { + keySet.erase(key); + } + if (keySet.size() > maxCacheSize) { + CTRLog(CTRLogLevel::ERROR, "total key size error"); + return false; + } + return true; +} + +bool checkOffsets(std::set &offsetSet, const std::vector &swapInOffsets, + const std::vector &swapOutOffset) +{ + for (auto offset : swapOutOffset) { + if (offsetSet.find(offset) == offsetSet.end()) { + CTRLog(CTRLogLevel::ERROR, "swapOut offset error1"); + return false; + } + } + + for (auto offset : swapOutOffset) { + offsetSet.erase(offset); + } + + for (auto offset : swapInOffsets) { + if (offsetSet.find(offset) != offsetSet.end()) { + CTRLog(CTRLogLevel::ERROR, "swapIn offset error"); + return false; + } + offsetSet.insert(offset); + } + + return true; +} + + +TEST_F(EmbCacheTest, DEVICE_COMBINE_TEST) +{ + CTRLog(CTRLogLevel::INFO, "===========DEVICE_COMBINE_TEST start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 4000000; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 30000; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::set keySet; + std::set offsetSet; + std::vector> historyKeyVec; + std::vector> historyOffsetVec; + std::vector lookupKeys; + std::vector check_keys; + for (uint32_t i = 0; i < 50; i++) { + lookupKeys = GenKeys(10000, 123 + i, 0, 100000); + check_keys = lookupKeys; + std::pair, std::vector> koPair1; + std::pair, std::vector> koPair2; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, lookupKeys, koPair1, koPair2), H_OK); + bool retKey1 = checkKeys(keySet, historyKeyVec, check_keys, koPair1.first, koPair2.first, devVocabSize); + bool retOffset1 = checkOffsets(offsetSet, koPair1.second, koPair2.second); + ASSERT_EQ(retKey1, true); + ASSERT_EQ(retOffset1, true); + } + + CTRLog(CTRLogLevel::INFO, "===========DEVICE_COMBINE_TEST end============="); +} + +TEST_F(EmbCacheTest, REMOVE_KEYS) +{ + CTRLog(CTRLogLevel::INFO, "===========REMOVE_KEYS start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 100; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 10; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::vector lookupKeys; + std::vector removeKeys; + float *addr; + float *newEmb; + + for (uint32_t i = 0; i < hostVocabSize - 1; i++) { + lookupKeys.emplace_back(i); + for (uint32_t j = 0; j < hostVocabSize - 1; j++) { + removeKeys.emplace_back(i + j); + } + } + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + free(addr); + + // 表存在 + ASSERT_EQ(embCache->RemoveEmbsByKeys(tableName, lookupKeys), H_OK); + + // 表不存在 + ASSERT_EQ(embCache->RemoveEmbsByKeys("not_a_table", lookupKeys), H_TABLE_NOT_EXIST); + + // 表名超过上限 + ASSERT_EQ(embCache->RemoveEmbsByKeys(tooLongTableName, lookupKeys), H_TABLE_NAME_TOO_LONG); + + // remove INVALID_KEY + uint64_t neg_one = -1; + lookupKeys = { neg_one, neg_one, neg_one, neg_one, neg_one }; + ASSERT_EQ(embCache->RemoveEmbsByKeys(tableName, lookupKeys), H_OK); + + // 判断embLocalTable是否remove掉记录信息 + lookupKeys = { 0, 1, 4 }; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + free(addr); + + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 999.99f; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_OK); + free(newEmb); + + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + bool ret1 = true; + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + if (fabs(addr[i] - 999.99f) > 0.0000001) { + ret1 = false; + } + } + free(addr); + ASSERT_EQ(ret1, true); + + ASSERT_EQ(embCache->RemoveEmbsByKeys(tableName, lookupKeys), H_OK); + + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + bool ret2 = true; + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + if (fabs(addr[i] - 999.99f) <= 0.0000001) { + ret2 = false; + } + } + free(addr); + ASSERT_EQ(ret2, true); + + // 判断offsetMapper是否remove掉记录信息 + lookupKeys = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + std::pair, std::vector> swapInKoPair, swapOutKoPair; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, lookupKeys, swapInKoPair, swapOutKoPair), H_OK); + removeKeys = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + ASSERT_EQ(embCache->RemoveEmbsByKeys(tableName, removeKeys), H_OK); + std::vector> koVec; + ASSERT_EQ(embCache->ExportDeviceKeyOffsetPairs(tableName, koVec), H_OK); + bool ret3 = true; + for (uint32_t i = 0; i < koVec.size(); i++) { + if (std::find(removeKeys.begin(), removeKeys.end(), koVec[i].first) != removeKeys.end()) { + ret3 = false; + } + } + ASSERT_EQ(ret3, true); + // 判断删除后,还能再添加 + lookupKeys = { 9, 10, 11, 12, 13 }; + std::vector oldKeys = lookupKeys; + std::pair, std::vector> swapInKoPair2, swapOutKoPair2; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, lookupKeys, swapInKoPair2, swapOutKoPair2), H_OK); + bool ret4 = true; + for (uint32_t i = 0; i < 5; i++) { + if (oldKeys[i] != swapInKoPair2.first[i]) { + ret4 = false; + } + } + bool ret5 = true; + for (uint32_t i = 0; i < 5; i++) { + if (lookupKeys[i] != swapInKoPair2.second[i]) { + ret5 = false; + } + } + ASSERT_EQ(ret4, true); + ASSERT_EQ(ret5, true); + ASSERT_EQ(swapInKoPair2.first.size(), 5ull); + ASSERT_EQ(swapInKoPair2.second.size(), 5ull); + ASSERT_EQ(swapOutKoPair2.first.empty(), true); + ASSERT_EQ(swapOutKoPair2.second.empty(), true); + + removeKeys = { 9, 10, 11, 3 }; + ASSERT_EQ(embCache->RemoveEmbsByKeys(tableName, removeKeys), H_OK); + std::vector> koVec2; + ASSERT_EQ(embCache->ExportDeviceKeyOffsetPairs(tableName, koVec2), H_OK); + bool ret6 = true; + for (uint32_t i = 0; i < koVec2.size(); i++) { + if (std::find(removeKeys.begin(), removeKeys.end(), koVec2[i].first) != removeKeys.end()) { + ret6 = false; + } + } + ASSERT_EQ(ret6, true); + + // 判断删除后,还能再添加 + lookupKeys = { 0, 1, 2, 3, 4, 5, 6, 7 }; + std::vector oldKeys2 = lookupKeys; + std::pair, std::vector> swapInKoPair3, swapOutKoPair3; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, lookupKeys, swapInKoPair3, swapOutKoPair3), H_OK); + bool ret7 = true; + for (uint32_t i = 0; i < 8; i++) { + if (oldKeys2[i] != swapInKoPair3.first[i]) { + ret7 = false; + } + } + bool ret8 = true; + for (uint32_t i = 0; i < 8; i++) { + if (lookupKeys[i] != swapInKoPair3.second[i]) { + ret8 = false; + } + } + ASSERT_EQ(ret7, true); + ASSERT_EQ(ret8, true); + ASSERT_EQ(swapInKoPair3.first.size(), 8ull); + ASSERT_EQ(swapInKoPair3.second.size(), 8ull); + ASSERT_EQ(swapOutKoPair3.first.empty(), true); + ASSERT_EQ(swapOutKoPair3.second.empty(), true); + + lookupKeys = { 15 }; + std::pair, std::vector> swapInKoPair4, swapOutKoPair4; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, lookupKeys, swapInKoPair4, swapOutKoPair4), + H_OK); + + CTRLog(CTRLogLevel::INFO, "===========REMOVE_KEYS end============="); +} + +TEST_F(EmbCacheTest, ExportDeviceKeyOffsetPairs) +{ + CTRLog(CTRLogLevel::INFO, "===========ExportDeviceKeyOffsetPairs start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 10; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 8; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + + // 使用不存在的table名字 + std::vector> koVec; + ASSERT_EQ(embCache->ExportDeviceKeyOffsetPairs("not_a_table", koVec), H_TABLE_NOT_EXIST); + + ASSERT_EQ(embCache->ExportDeviceKeyOffsetPairs(tooLongTableName, koVec), H_TABLE_NAME_TOO_LONG); + + // 正常export出koPair + std::vector lookupKeys; + std::vector checkKeys; + lookupKeys = { 6, 0, 8, 1, 3, 4 }; + checkKeys = lookupKeys; + std::pair, std::vector> swapInKoPair, swapOutKoPair; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, lookupKeys, swapInKoPair, swapOutKoPair), H_OK); + std::vector> koVec2; + ASSERT_EQ(embCache->ExportDeviceKeyOffsetPairs(tableName, koVec2), H_OK); + ASSERT_EQ(koVec2.size(), lookupKeys.size()); + bool ret1 = true; + for (uint32_t i = 0; i < koVec2.size(); i++) { + if (koVec2[i].first != checkKeys[i] || koVec2[i].second != lookupKeys[i]) { + ret1 = false; + } + } + ASSERT_EQ(ret1, true); + + CTRLog(CTRLogLevel::INFO, "===========ExportDeviceKeyOffsetPairs end============="); +} + +TEST_F(EmbCacheTest, GetEmbTableNames) +{ + CTRLog(CTRLogLevel::INFO, "===========GetEmbTableNames start============="); + factory->CreateEmbCacheManager(embCache); + uint32_t hostVocabSize = 10; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 8; + std::vector tableNameVec; + tableNameVec.emplace_back("table1"); + tableNameVec.emplace_back("table2"); + tableNameVec.emplace_back("table3"); + for (const std::string tableName : tableNameVec) { + EmbCache::EmbCacheInfo embCacheInfo(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + + EmbCache::NormalInitializerInfo normalInitializerInfo(0, 0.5, 0, 1.0); + std::string normalInitializeName = "random_normal_initializer"; + EmbCache::InitializerInfo normalInitializeInfo(normalInitializeName, 0, embeddingSize, normalInitializerInfo); + + EmbCache::ConstantInitializerInfo constantInitializerInfo(0.233, 1.0); + std::string constantInitializeName = "constant_initializer"; + EmbCache::InitializerInfo constantInitializeInfo(constantInitializeName, embeddingSize, embeddingSize, + constantInitializerInfo); + + std::vector initializeInfos(extEmbeddingSize / embeddingSize); + initializeInfos[0] = normalInitializeInfo; + for (uint64_t i = 1; i < initializeInfos.size(); i++) { + initializeInfos[i] = constantInitializeInfo; + } + embCache->CreateCacheForTable(embCacheInfo, initializeInfos, -1, hostVocabSize); + } + std::vector allTableNames; + std::vector notEmptyVector = { "123" }; + ASSERT_EQ(embCache->GetEmbTableNames(notEmptyVector), H_ARG_NOT_EMPTY); + + ASSERT_EQ(embCache->GetEmbTableNames(allTableNames), H_OK); + bool ret1 = true; + for (auto tableName : allTableNames) { + if (std::find(tableNameVec.begin(), tableNameVec.end(), tableName) == tableNameVec.end()) { + ret1 = false; + } + } + for (auto tableName : tableNameVec) { + if (std::find(allTableNames.begin(), allTableNames.end(), tableName) == allTableNames.end()) { + ret1 = false; + } + } + ASSERT_EQ(ret1, true); + + CTRLog(CTRLogLevel::INFO, "===========GetEmbTableNames end============="); +} + +TEST_F(EmbCacheTest, SERIALIZE) +{ + CTRLog(CTRLogLevel::INFO, "===========SERIALIZE start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 5; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 2; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + + std::vector lookupKeys; + + lookupKeys = { 0 }; + std::vector buffer; + ASSERT_EQ(embCache->Serialize("not_a_table", buffer), H_TABLE_NOT_EXIST); + // 表名超过上限 + ASSERT_EQ(embCache->Serialize(tooLongTableName, buffer), H_TABLE_NAME_TOO_LONG); + CTRLog(CTRLogLevel::INFO, "===========SERIALIZE end============="); +} + +TEST_F(EmbCacheTest, DESERIALIZE) +{ + CTRLog(CTRLogLevel::INFO, "===========DESERIALIZE start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 5; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 2; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + + std::vector lookupKeys; + + lookupKeys = { 0 }; + std::vector buffer = { 'A', 'B', '1', '2' }; + ASSERT_EQ(embCache->Deserialize("not_a_table", buffer), H_TABLE_NOT_EXIST); + + ASSERT_EQ(embCache->Deserialize(tooLongTableName, buffer), H_TABLE_NAME_TOO_LONG); + + ASSERT_EQ(embCache->Deserialize(tableName, buffer), H_BUFFER_INVALID); + + lookupKeys = { 0, 1, 2, 3, 4 }; + float *newEmb; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_OK); + free(newEmb); + std::vector buffer1; + ASSERT_EQ(embCache->Serialize(tableName, buffer1), H_OK); + buffer1.erase(buffer1.begin() + buffer1.size() / 2, buffer1.end()); + ASSERT_EQ(embCache->Deserialize(tableName, buffer1), H_BUFFER_INVALID); + + CTRLog(CTRLogLevel::INFO, "===========DESERIALIZE end============="); +} + +TEST_F(EmbCacheTest, SERIALIZE_DESERIALIZE) +{ + CTRLog(CTRLogLevel::INFO, "===========SERIALIZE_DESERIALIZE start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 5; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 2; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + + std::vector lookupKeys; + lookupKeys = { 0, 1, 2, 3, 4 }; + float *newEmb; + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 0.01f * i; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_OK); + free(newEmb); + + std::vector buffer1; + std::vector buffer2; + + ASSERT_EQ(embCache->Serialize(tableName, buffer1), H_OK); + ASSERT_EQ(embCache->Deserialize(tableName, buffer1), H_OK); + ASSERT_EQ(embCache->Serialize(tableName, buffer2), H_OK); + ASSERT_EQ(buffer1.size(), buffer2.size()); + for (uint64_t i = 0; i < buffer1.size(); i++) { + ASSERT_EQ(buffer1[i], buffer2[i]); + } + ASSERT_EQ(buffer1, buffer2); + CTRLog(CTRLogLevel::INFO, "===========SERIALIZE_DESERIALIZE end============="); +} + +TEST_F(EmbCacheTest, ERROR_INITIALIZER) +{ + CTRLog(CTRLogLevel::INFO, "===========ERROR_INITIALIZER start============="); + uint32_t embeddingSize = 13; + /* 对ConstantInitializerInfo的constValue和initK的校验 */ + std::string constantInitializeName = "constant_initializer"; + // 日志打印"constant value is less than -1000000000, and will use -1000000000.",并正常初始化InitializerInfo + EmbCache::ConstantInitializerInfo constantInitializerInfo1(-1e9 - 1e8, 1.0); + EmbCache::InitializerInfo constantInitializeInfo = + EmbCache::InitializerInfo(constantInitializeName, embeddingSize, embeddingSize + 1, constantInitializerInfo1); + + // 日志打印"constant value is greater than 1000000000, and will use 1000000000.",并正常初始化InitializerInfo + EmbCache::ConstantInitializerInfo constantInitializerInfo2(1e9 + 1e8, 1.0); + constantInitializeInfo = + EmbCache::InitializerInfo(constantInitializeName, embeddingSize, embeddingSize + 1, constantInitializerInfo2); + + // 日志打印"constant initK is greater than 10000, and will use 10000.",并正常初始化InitializerInfo + EmbCache::ConstantInitializerInfo constantInitializerInfo3(0.233, 10001); + constantInitializeInfo = + EmbCache::InitializerInfo(constantInitializeName, embeddingSize, embeddingSize + 1, constantInitializerInfo3); + + // 日志打印"constant initK is less than -10000, and will use -10000.",并正常初始化InitializerInfo + EmbCache::ConstantInitializerInfo constantInitializerInfo4(0.233, -10001); + constantInitializeInfo = + EmbCache::InitializerInfo(constantInitializeName, embeddingSize, embeddingSize + 1, constantInitializerInfo4); + + /* 对NormalIntializerInfo的mean、stdev和initK的校验 */ + std::string normalInitializeName = "random_normal_initializer"; + // 日志打印"random normal mean param is greater than 1000000000, and will use + // 1000000000.",并正常初始化InitializerInfo + EmbCache::NormalInitializerInfo normalInitializerInfo1(1e9 + 1e8, 0.05, 0, 1.0); + EmbCache::InitializerInfo normalInitializeInfo = + EmbCache::InitializerInfo(normalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo1); + + // 日志打印"random normal mean param is less than -1000000000, and will use + // -1000000000.",并正常初始化InitializerInfo + EmbCache::NormalInitializerInfo normalInitializerInfo2(-1e9 - 1e8, 0.05, 0, 1.0); + normalInitializeInfo = + EmbCache::InitializerInfo(normalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo2); + + // 日志打印"random normal stddev param is greater than 100, and will use 100.",并正常初始化InitializerInfo + EmbCache::NormalInitializerInfo normalInitializerInfo3(0, 101, 0, 1.0); + normalInitializeInfo = + EmbCache::InitializerInfo(normalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo3); + + // 日志打印"random normal stddev param is less than 0, and will use 0.",并正常初始化InitializerInfo + EmbCache::NormalInitializerInfo normalInitializerInfo4(0, -1, 0, 1.0); + normalInitializeInfo = + EmbCache::InitializerInfo(normalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo4); + // 日志打印"random normal initK is greater than 10000, and will use 10000.",并正常初始化InitializerInfo + EmbCache::NormalInitializerInfo normalInitializerInfo5(0, 0.05, 0, 10001); + normalInitializeInfo = + EmbCache::InitializerInfo(normalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo5); + // 日志打印"random normal initK is less than -10000, and will use -10000.",并正常初始化InitializerInfo + EmbCache::NormalInitializerInfo normalInitializerInfo6(0, 0.05, 0, -10001); + normalInitializeInfo = + EmbCache::InitializerInfo(normalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo6); + + /* 对TruncatedNormalInitializer的mean、stdev以及initK的校验 */ + std::string truncatedNormalInitializeName = "truncated_normal_initializer"; + // 日志打印"truncated normal mean param is greater than 1000000000, and will use + // 1000000000.",并正常初始化InitializerInfo + EmbCache::NormalInitializerInfo normalInitializerInfo7(1e9 + 1e8, 0.05, 0, 1.0); + EmbCache::InitializerInfo truncatedNormalInitializeInfo = + EmbCache::InitializerInfo(truncatedNormalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo7); + + // 日志打印"truncated normal mean param is less than -1000000000, and will use + // -1000000000.",并正常初始化InitializerInfo + EmbCache::NormalInitializerInfo normalInitializerInfo8(-1e9 - 1e8, 0.05, 0, 1.0); + truncatedNormalInitializeInfo = + EmbCache::InitializerInfo(truncatedNormalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo8); + + // 日志打印"truncated normal stddev param is greater than 100, and will use 100.",并正常初始化InitializerInfo + EmbCache::NormalInitializerInfo normalInitializerInfo9(0, 101, 0, 1.0); + truncatedNormalInitializeInfo = + EmbCache::InitializerInfo(truncatedNormalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo9); + + // 日志打印"truncated normal stddev param is less than 0.000000, and will use 0.000000."并正常初始化InitializerInfo + EmbCache::NormalInitializerInfo normalInitializerInfo10(0, -1, 0, 1.0); + truncatedNormalInitializeInfo = + EmbCache::InitializerInfo(truncatedNormalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo10); + // 日志打印"truncated normal initK is greater than 10000, and will use 10000.",并正常初始化InitializerInfo + EmbCache::NormalInitializerInfo normalInitializerInfo11(0, 0.05, 0, 10001); + truncatedNormalInitializeInfo = + EmbCache::InitializerInfo(truncatedNormalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo11); + // 日志打印"truncated normal initK is less than -10000, and will use -10000." + EmbCache::NormalInitializerInfo normalInitializerInfo12(0, 0.05, 0, -10001); + truncatedNormalInitializeInfo = + EmbCache::InitializerInfo(truncatedNormalInitializeName, embeddingSize, embeddingSize, normalInitializerInfo12); + CTRLog(CTRLogLevel::INFO, "===========ERROR_INITIALIZER end============="); +} + + +TEST_F(EmbCacheTest, EmbeddingRemove) +{ + CTRLog(CTRLogLevel::INFO, "===========EmbeddingRemove start============="); + std::string tableName = "test_table"; + uint32_t hostVocabSize = 100; + uint32_t embeddingSize = 13; + uint32_t extEmbeddingSize = 26; + uint32_t devVocabSize = 100; + embCache = SimpleCreateTable(tableName, hostVocabSize, embeddingSize, extEmbeddingSize, devVocabSize); + std::vector lookupKeys; + std::vector removeKeys; + float *addr; + float *newEmb; + + for (uint32_t i = 0; i < hostVocabSize - 1; i++) { + lookupKeys.emplace_back(i); + for (uint32_t j = 0; j < hostVocabSize - 1; j++) { + removeKeys.emplace_back(i + j); + } + } + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + // 表存在 + ASSERT_EQ(embCache->EmbeddingRemove(tableName, lookupKeys), H_OK); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + // 单线程 + ASSERT_EQ(embCache->EmbeddingRemove(tableName, lookupKeys, 1), H_OK); + + free(addr); + // REMOVE空keys + std::vector emptyRemoveKeys; + ASSERT_EQ(embCache->EmbeddingRemove(tableName, emptyRemoveKeys), H_OK); + + // 表不存在 + ASSERT_EQ(embCache->EmbeddingRemove("not_a_table", lookupKeys), H_TABLE_NOT_EXIST); + // 表名超过上限 + ASSERT_EQ(embCache->EmbeddingRemove(tooLongTableName, lookupKeys), H_TABLE_NAME_TOO_LONG); + + // remove INVALID_KEY + uint64_t neg_one = -1; + lookupKeys = { neg_one, neg_one, neg_one, neg_one, neg_one }; + ASSERT_EQ(embCache->EmbeddingRemove(tableName, lookupKeys), H_OK); + + // 判断embLocalTable是否remove掉记录信息 + lookupKeys = { 0, 1, 4 }; + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + free(addr); + + newEmb = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + newEmb[i] = 999.99f; + } + ASSERT_EQ(embCache->EmbeddingUpdate(tableName, lookupKeys, newEmb), H_OK); + free(newEmb); + + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + bool ret1 = true; + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + if (fabs(addr[i] - 999.99f) > 0.0000001) { + ret1 = false; + } + } + free(addr); + ASSERT_EQ(ret1, true); + + ASSERT_EQ(embCache->EmbeddingRemove(tableName, lookupKeys), H_OK); + + addr = (float *)malloc(lookupKeys.size() * extEmbeddingSize * sizeof(float)); + ASSERT_EQ(embCache->EmbeddingLookup(tableName, lookupKeys, addr), H_OK); + bool ret2 = true; + for (uint32_t i = 0; i < lookupKeys.size() * extEmbeddingSize; i++) { + if (fabs(addr[i] - 999.99f) <= 0.0000001) { + ret2 = false; + } + } + free(addr); + ASSERT_EQ(ret2, true); + + // 判断offsetMapper是否remove掉记录信息 + lookupKeys = { 6, 0, 8, 1, 3, 4 }; + std::pair, std::vector> swapInKoPair, swapOutKoPair; + ASSERT_EQ(embCache->GetSwapPairsAndKey2Offset(tableName, lookupKeys, swapInKoPair, swapOutKoPair), H_OK); + removeKeys = { 0, 1, 4 }; + ASSERT_EQ(embCache->EmbeddingRemove(tableName, removeKeys), H_OK); + + CTRLog(CTRLogLevel::INFO, "===========EmbeddingRemove end============="); +} diff --git a/src/AccCTR/tests/ut/src/emb_cache_test.h b/src/AccCTR/tests/ut/src/emb_cache_test.h new file mode 100644 index 00000000..e8e2837d --- /dev/null +++ b/src/AccCTR/tests/ut/src/emb_cache_test.h @@ -0,0 +1,62 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef CTR_EMB_CACHE_TEST_H +#define CTR_EMB_CACHE_TEST_H + +#include +#include +#include +#include +#include "factory.h" +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +using namespace std; +using namespace ock::ctr; + +class EmbCacheTest : public testing::Test { +protected: + EmbCacheTest(){}; + ~EmbCacheTest(){}; + static void SetUpTestCase(); + static void TearDownTestCase(); + + + void SetUp() override; + + void TearDown() override; + + static EmbCacheManagerPtr SimpleCreateTable(std::string tableName, uint32_t hostVocabSize, uint32_t embeddingSize, + uint32_t extEmbeddingSize, uint32_t devVocabSize, pair normalPara = { 0, 0.05 }, + float constPara = 0.233); + + static EmbCacheManagerPtr ConstZeroCreateTable(std::string tableName, uint32_t hostVocabSize, + uint32_t embeddingSize, uint32_t extEmbeddingSize, uint32_t devVocabSize, uint64_t prefillBufferSize = 50000, + uint8_t prefillThreadNum = 1); + + std::string tooLongTableName = + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100000000010000000001000000000100000000010000000001000000000100000000010000000001" + "00000000010000000001000000000100012"; +}; + +#endif // CTR_EMB_CACHE_TEST_H diff --git a/src/AccCTR/tests/ut/src/unique_test.cpp b/src/AccCTR/tests/ut/src/unique_test.cpp index f971bb91..a94ebaf7 100644 --- a/src/AccCTR/tests/ut/src/unique_test.cpp +++ b/src/AccCTR/tests/ut/src/unique_test.cpp @@ -15,8 +15,7 @@ limitations under the License. #include #include #include "unique_test.h" - -FactoryPtr factory; +#include "common.h" void UniqueTest::SetUpTestCase() { @@ -144,7 +143,10 @@ TEST_F(UniqueTest, DoUniqueNormal) std::string input_path(path); std::cout << "input_path:" + input_path + "/data30.txt" << std::endl; std::ifstream input(input_path + "/data30.txt"); - + if(!input.good()) { + std::cout << "Failed to open file:" + input_path + "/data30.txt" << std::endl; + return; + } std::vector numbers; std::string line; while (std::getline(input, line, ',')) { @@ -156,6 +158,8 @@ TEST_F(UniqueTest, DoUniqueNormal) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.trace = true; conf.desiredSize = numbers.size(); @@ -213,6 +217,8 @@ TEST_F(UniqueTest, UseErrOutputTypeEnhanced) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.desiredSize = 6; conf.dataType = DataType::INT64; @@ -253,6 +259,8 @@ TEST_F(UniqueTest, UseErrOutputTypeNormal) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.desiredSize = 6; conf.dataType = DataType::INT64; @@ -292,6 +300,8 @@ TEST_F(UniqueTest, DoEnhancedUnique) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.desiredSize = 6; conf.dataType = DataType::INT64; @@ -340,6 +350,8 @@ TEST_F(UniqueTest, DoEnhancedUniqueErr) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.desiredSize = 6; conf.dataType = DataType::INT64; @@ -402,6 +414,8 @@ TEST_F(UniqueTest, DoEnhancedUnique_UniqueIdSize) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.desiredSize = 6; conf.dataType = DataType::INT64; @@ -449,6 +463,8 @@ TEST_F(UniqueTest, idCntIsNull) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.desiredSize = 6; conf.dataType = DataType::INT64; @@ -488,6 +504,8 @@ TEST_F(UniqueTest, idCntIsNullSharding) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.desiredSize = 6; conf.dataType = DataType::INT64; @@ -537,6 +555,8 @@ TEST_F(UniqueTest, DoUniqueShard) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.useSharding = true; conf.useIdCount = true; @@ -612,6 +632,8 @@ TEST_F(UniqueTest, DoUniqueOnlyShard) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.useSharding = true; conf.desiredSize = 6; @@ -675,6 +697,8 @@ TEST_F(UniqueTest, DoUniquePadding) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.usePadding = true; conf.useSharding = true; @@ -755,6 +779,8 @@ TEST_F(UniqueTest, DoUniqueNoThreadPool) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.desiredSize = 20; // 配置空间大于实际输入数组长度,验证正常运行 conf.dataType = DataType::INT64; @@ -817,6 +843,8 @@ TEST_F(UniqueTest, DoUniqueShardNumberOversize) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.useSharding = true; conf.desiredSize = 6; @@ -895,6 +923,7 @@ TEST_F(UniqueTest, DoUniqueSpecial) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); int count = 1000000; UniqueConf conf; @@ -963,6 +992,8 @@ TEST_F(UniqueTest, IdLarge) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.desiredSize = 6; conf.dataType = DataType::INT64; @@ -999,6 +1030,8 @@ TEST_F(UniqueTest, DoUniqueNormalInt32) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.useSharding = true; conf.desiredSize = 6; @@ -1122,6 +1155,8 @@ TEST_F(UniqueTest, DoUniqueShardMultipleTimes) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.useSharding = true; conf.desiredSize = 6; @@ -1286,6 +1321,8 @@ TEST_F(UniqueTest, IdCntSmall) UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); + factory->SetExternalLogFuncInner(CTRLog); + UniqueConf conf; conf.desiredSize = 6; conf.dataType = DataType::INT64; @@ -1321,7 +1358,10 @@ TEST_F(UniqueTest, DoUniqueLotsDataFunction) std::string input_path(path); std::cout << "input_path:" + input_path + "/data40.txt" << std::endl; std::ifstream input(input_path + "/data40.txt"); - + if(!input.good()) { + std::cout << "Failed to open file:" + input_path + "/data40.txt" << std::endl; + return; + } std::vector numbers; std::string line; while (std::getline(input, line, ',')) { @@ -1423,7 +1463,10 @@ TEST_F(UniqueTest, DoUniqueLotsDataPaddingFunction) std::string input_path(path); std::cout << "input_path:" + input_path + "/data30.txt" << std::endl; std::ifstream input(input_path + "/data30.txt"); - + if(!input.good()) { + std::cout << "Failed to open file:" + input_path + "/data30.txt" << std::endl; + return; + } std::vector numbers; std::string line; while (std::getline(input, line, ',')) { diff --git a/src/AccCTR/tests/ut/src/unique_test.h b/src/AccCTR/tests/ut/src/unique_test.h index 0243f262..c3bc64f3 100644 --- a/src/AccCTR/tests/ut/src/unique_test.h +++ b/src/AccCTR/tests/ut/src/unique_test.h @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include "factory.h" #include "gtest/gtest.h" #include "gmock/gmock.h" #include "unique.h" @@ -28,21 +27,6 @@ using namespace std; using namespace ock::ctr; -class SimpleThreadPool { -public: - static void SyncRun(const std::vector> &tasks) - { - std::vector> futs; - for (auto &task : tasks) { - futs.push_back(std::async(task)); - } - for (auto &fut : futs) { - fut.wait(); - } - } -}; - - class UniqueTest : public testing::Test { protected: UniqueTest() {}; -- Gitee From 101385e69fa25c4d5667dfc84b94ddc3b533bd95 Mon Sep 17 00:00:00 2001 From: yangzhen Date: Fri, 10 May 2024 11:59:50 +0800 Subject: [PATCH 2/7] cleancode --- .../offset_mapper/mapper_base.h | 74 ++++++++++--------- .../offset_mapper/offset_mapper.h | 2 + src/AccCTR/tests/ut/src/emb_cache_test.cpp | 3 + src/AccCTR/tests/ut/src/emb_cache_test.h | 14 ++-- 4 files changed, 52 insertions(+), 41 deletions(-) diff --git a/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h b/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h index 363c59ee..6f4debb6 100644 --- a/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h +++ b/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h @@ -37,6 +37,12 @@ namespace EmbCache { static constexpr size_t K_ALIGNMENT = 64; static constexpr size_t K_KVNUMINBUCKET = 3; +enum BucketIdx { + first, + second, + third +}; + class NetHeapAllocator { public: void *Allocate(uint32_t size) @@ -94,17 +100,17 @@ struct alignas(K_ALIGNMENT)NetHashBucket { { /* don't put them into loop, flat code is faster than loop */ uint64_t oldKey = 0; - if (keys[0].load(std::memory_order_relaxed) == 0 && keys[0].compare_exchange_strong(oldKey, key)) { + if (keys[BucketIdx::first].load(std::memory_order_relaxed) == 0 && keys[BucketIdx::first].compare_exchange_strong(oldKey, key)) { BeforePutFuncState ret = beforePutFunc(); if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_FAIL)) { - keys[0] = 0; + keys[BucketIdx::first] = 0; return FkvState::FKV_BEFORE_PUT_FUNC_FAIL; } if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_NO_SPACE)) { - keys[0] = 0; + keys[BucketIdx::first] = 0; return FkvState::FKV_NO_SPACE; } - values[0] = value; + values[BucketIdx::first] = value; return FkvState::FKV_NOT_EXIST; } @@ -113,17 +119,17 @@ struct alignas(K_ALIGNMENT)NetHashBucket { } oldKey = 0; - if (keys[1].load(std::memory_order_relaxed) == 0 && keys[1].compare_exchange_strong(oldKey, key)) { + if (keys[BucketIdx::second].load(std::memory_order_relaxed) == 0 && keys[BucketIdx::second].compare_exchange_strong(oldKey, key)) { BeforePutFuncState ret = beforePutFunc(); if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_FAIL)) { - keys[1] = 0; + keys[BucketIdx::second] = 0; return FkvState::FKV_BEFORE_PUT_FUNC_FAIL; } if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_NO_SPACE)) { - keys[1] = 0; + keys[BucketIdx::second] = 0; return FkvState::FKV_NO_SPACE; } - values[1] = value; + values[BucketIdx::second] = value; return FkvState::FKV_NOT_EXIST; } @@ -132,17 +138,17 @@ struct alignas(K_ALIGNMENT)NetHashBucket { } oldKey = 0; - if (keys[2].load(std::memory_order_relaxed) == 0 && keys[2].compare_exchange_strong(oldKey, key)) { + if (keys[BucketIdx::third].load(std::memory_order_relaxed) == 0 && keys[BucketIdx::third].compare_exchange_strong(oldKey, key)) { BeforePutFuncState ret = beforePutFunc(); if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_FAIL)) { - keys[2] = 0; + keys[BucketIdx::third] = 0; return FkvState::FKV_BEFORE_PUT_FUNC_FAIL; } if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_NO_SPACE)) { - keys[2] = 0; + keys[BucketIdx::third] = 0; return FkvState::FKV_NO_SPACE; } - values[2] = value; + values[BucketIdx::third] = value; return FkvState::FKV_NOT_EXIST; } @@ -161,18 +167,18 @@ struct alignas(K_ALIGNMENT)NetHashBucket { /* * expand the loop, instead of put them into a for/while loop for performance */ - if (key == keys[0].load(std::memory_order_relaxed)) { - value = values[0]; + if (key == keys[BucketIdx::first].load(std::memory_order_relaxed)) { + value = values[BucketIdx::first]; return true; } - if (key == keys[1].load(std::memory_order_relaxed)) { - value = values[1]; + if (key == keys[BucketIdx::second].load(std::memory_order_relaxed)) { + value = values[BucketIdx::second]; return true; } - if (key == keys[2].load(std::memory_order_relaxed)) { - value = values[2]; + if (key == keys[BucketIdx::third].load(std::memory_order_relaxed)) { + value = values[BucketIdx::third]; return true; } @@ -183,8 +189,8 @@ struct alignas(K_ALIGNMENT)NetHashBucket { { /* don't put them into loop, flat code is faster than loop */ uint64_t oldValue = key; - if (keys[0].load(std::memory_order_relaxed) == key && keys[0].compare_exchange_strong(oldValue, 0)) { - values[0] = 0; + if (keys[BucketIdx::first].load(std::memory_order_relaxed) == key && keys[BucketIdx::first].compare_exchange_strong(oldValue, 0)) { + values[BucketIdx::first] = 0; return FkvState::FKV_EXIST; } if (HM_UNLIKELY(oldValue == 0)) { @@ -192,8 +198,8 @@ struct alignas(K_ALIGNMENT)NetHashBucket { } oldValue = key; - if (keys[1].load(std::memory_order_relaxed) == key && keys[1].compare_exchange_strong(oldValue, 0)) { - values[1] = 0; + if (keys[BucketIdx::second].load(std::memory_order_relaxed) == key && keys[BucketIdx::second].compare_exchange_strong(oldValue, 0)) { + values[BucketIdx::second] = 0; return FkvState::FKV_EXIST; } if (HM_UNLIKELY(oldValue == 0)) { @@ -201,8 +207,8 @@ struct alignas(K_ALIGNMENT)NetHashBucket { } oldValue = key; - if (keys[2].load(std::memory_order_relaxed) == key && keys[2].compare_exchange_strong(oldValue, 0)) { - values[2] = 0; + if (keys[BucketIdx::third].load(std::memory_order_relaxed) == key && keys[BucketIdx::third].compare_exchange_strong(oldValue, 0)) { + values[BucketIdx::third] = 0; return FkvState::FKV_EXIST; } if (HM_UNLIKELY(oldValue == 0)) { @@ -216,12 +222,12 @@ struct alignas(K_ALIGNMENT)NetHashBucket { { /* don't put them into loop, flat code is faster than loop */ uint64_t oldValue = key; - if (keys[0].load(std::memory_order_relaxed) == key && keys[0].compare_exchange_strong(oldValue, 0)) { - if (HM_UNLIKELY(beforeRemoveFunc(values[0]) == BeforeRemoveFuncState::BEFORE_FAIL)) { + if (keys[BucketIdx::first].load(std::memory_order_relaxed) == key && keys[BucketIdx::first].compare_exchange_strong(oldValue, 0)) { + if (HM_UNLIKELY(beforeRemoveFunc(values[BucketIdx::first]) == BeforeRemoveFuncState::BEFORE_FAIL)) { return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; } - values[0] = 0; + values[BucketIdx::first] = 0; return FkvState::FKV_EXIST; } if (HM_UNLIKELY(oldValue == 0)) { @@ -229,12 +235,12 @@ struct alignas(K_ALIGNMENT)NetHashBucket { } oldValue = key; - if (keys[1].load(std::memory_order_relaxed) == key && keys[1].compare_exchange_strong(oldValue, 0)) { - if (HM_UNLIKELY(beforeRemoveFunc(values[1]) == BeforeRemoveFuncState::BEFORE_FAIL)) { + if (keys[BucketIdx::second].load(std::memory_order_relaxed) == key && keys[BucketIdx::second].compare_exchange_strong(oldValue, 0)) { + if (HM_UNLIKELY(beforeRemoveFunc(values[BucketIdx::second]) == BeforeRemoveFuncState::BEFORE_FAIL)) { return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; } - values[1] = 0; + values[BucketIdx::second] = 0; return FkvState::FKV_EXIST; } if (HM_UNLIKELY(oldValue == 0)) { @@ -242,12 +248,12 @@ struct alignas(K_ALIGNMENT)NetHashBucket { } oldValue = key; - if (keys[2].load(std::memory_order_relaxed) == key && keys[2].compare_exchange_strong(oldValue, 0)) { - if (HM_UNLIKELY(beforeRemoveFunc(values[2]) == BeforeRemoveFuncState::BEFORE_FAIL)) { + if (keys[BucketIdx::third].load(std::memory_order_relaxed) == key && keys[BucketIdx::third].compare_exchange_strong(oldValue, 0)) { + if (HM_UNLIKELY(beforeRemoveFunc(values[BucketIdx::third]) == BeforeRemoveFuncState::BEFORE_FAIL)) { return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; } - values[2] = 0; + values[BucketIdx::third] = 0; return FkvState::FKV_EXIST; } if (HM_UNLIKELY(oldValue == 0)) { @@ -681,7 +687,7 @@ public: for (uint32_t j = 0; j < mBucketCount; j++) { auto buck = &mSubMap[j]; while (buck) { - for (int k = 0; k < 3; k++) { + for (size_t k = 0; k < K_KVNUMINBUCKET; k++) { if (buck->keys[k] == 0) { continue; } diff --git a/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h b/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h index 39dba5c2..da69fcb9 100644 --- a/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h +++ b/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h @@ -54,6 +54,8 @@ public: { delete validPos; delete evictPos; + validPos = nullptr; + evictPos = nullptr; MapperBase::UnInitialize(); } diff --git a/src/AccCTR/tests/ut/src/emb_cache_test.cpp b/src/AccCTR/tests/ut/src/emb_cache_test.cpp index 582f6e68..9534ef04 100644 --- a/src/AccCTR/tests/ut/src/emb_cache_test.cpp +++ b/src/AccCTR/tests/ut/src/emb_cache_test.cpp @@ -19,6 +19,9 @@ limitations under the License. #include "emb_cache_test.h" #include "common.h" +using namespace std; +using namespace ock::ctr; + FactoryPtr factory; EmbCacheManagerPtr embCache = nullptr; diff --git a/src/AccCTR/tests/ut/src/emb_cache_test.h b/src/AccCTR/tests/ut/src/emb_cache_test.h index e8e2837d..5c87237b 100644 --- a/src/AccCTR/tests/ut/src/emb_cache_test.h +++ b/src/AccCTR/tests/ut/src/emb_cache_test.h @@ -15,16 +15,16 @@ limitations under the License. #ifndef CTR_EMB_CACHE_TEST_H #define CTR_EMB_CACHE_TEST_H -#include #include #include #include -#include "factory.h" + #include "gtest/gtest.h" #include "gmock/gmock.h" -using namespace std; -using namespace ock::ctr; +#include "factory.h" +#include "embedding_cache.h" + class EmbCacheTest : public testing::Test { protected: @@ -38,11 +38,11 @@ protected: void TearDown() override; - static EmbCacheManagerPtr SimpleCreateTable(std::string tableName, uint32_t hostVocabSize, uint32_t embeddingSize, - uint32_t extEmbeddingSize, uint32_t devVocabSize, pair normalPara = { 0, 0.05 }, + static ock::ctr::EmbCacheManagerPtr SimpleCreateTable(std::string tableName, uint32_t hostVocabSize, uint32_t embeddingSize, + uint32_t extEmbeddingSize, uint32_t devVocabSize, std::pair normalPara = { 0, 0.05 }, float constPara = 0.233); - static EmbCacheManagerPtr ConstZeroCreateTable(std::string tableName, uint32_t hostVocabSize, + static ock::ctr::EmbCacheManagerPtr ConstZeroCreateTable(std::string tableName, uint32_t hostVocabSize, uint32_t embeddingSize, uint32_t extEmbeddingSize, uint32_t devVocabSize, uint64_t prefillBufferSize = 50000, uint8_t prefillThreadNum = 1); -- Gitee From 3b158c9f497528dece4a7465a40c874ce57fe9d7 Mon Sep 17 00:00:00 2001 From: yangzhen Date: Fri, 10 May 2024 17:52:50 +0800 Subject: [PATCH 3/7] cleancode --- src/AccCTR/src/CMakeLists.txt | 3 +- .../cache_manager/cache_manager.cpp | 3 +- src/AccCTR/src/embedding_cache/common.h | 97 +------ src/AccCTR/src/embedding_cache/limited_set.h | 119 +++++++++ .../offset_mapper/mapper_base.h | 238 ++++++++++-------- .../offset_mapper/offset_mapper.h | 108 +++++--- src/AccCTR/src/unique/unique_func.cpp | 4 + 7 files changed, 328 insertions(+), 244 deletions(-) create mode 100644 src/AccCTR/src/embedding_cache/limited_set.h diff --git a/src/AccCTR/src/CMakeLists.txt b/src/AccCTR/src/CMakeLists.txt index 5aaa168d..1f4d9269 100644 --- a/src/AccCTR/src/CMakeLists.txt +++ b/src/AccCTR/src/CMakeLists.txt @@ -23,10 +23,11 @@ set(OUTPUT ${PROJECT_SOURCE_DIR}/output) set(OCK_CTR_PLATFORM_UTIL_DIR ${PROJECT_SOURCE_DIR}/../../../opensource) set(OCK_CTR_UTIL_INSTALL_DIR ${PROJECT_SOURCE_DIR}/install) -add_compile_options(-D_GLIBCXX_USE_CXX11_ABI=0) # must set this option otherwise pybind will not find embCache symbol if (${BUILD_MODE} MATCHES "ut") add_compile_options(-ftest-coverage -fprofile-arcs) link_libraries(gcov) +else() + add_compile_options(-D_GLIBCXX_USE_CXX11_ABI=0) # must set this option otherwise pybind will not find embCache symbol endif (${BUILD_MODE} MATCHES "ut") if (${BUILD_MODE} MATCHES "fuzz") diff --git a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp index 991307fd..c80be2b6 100644 --- a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp +++ b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp @@ -96,8 +96,7 @@ int EmbCacheManagerImpl::GetSwapPairsAndKey2Offset(std::string tableName, std::v if (checkRet != H_OK) { return checkRet; } - return offsetMappers[tableName].GetSwapPairsAndKey2Offset(keys, swapInKoPair.first, swapInKoPair.second, - swapOutKoPair.first, swapOutKoPair.second); + return offsetMappers[tableName].GetSwapPairsAndKey2Offset(keys, swapInKoPair, swapOutKoPair); } int EmbCacheManagerImpl::EmbeddingLookup(std::string tableName, const std::vector &keys, float *embAddr, diff --git a/src/AccCTR/src/embedding_cache/common.h b/src/AccCTR/src/embedding_cache/common.h index 33a975c3..78b6985b 100644 --- a/src/AccCTR/src/embedding_cache/common.h +++ b/src/AccCTR/src/embedding_cache/common.h @@ -15,9 +15,7 @@ limitations under the License. #ifndef MXREC_COMMON_H #define MXREC_COMMON_H -#include -#include -#include +#include "limited_set.h" #ifndef HM_UNLIKELY #define HM_UNLIKELY(x) __builtin_expect(!!(x), 0) @@ -29,99 +27,6 @@ limitations under the License. namespace EmbCache { -class LimitedSet { -public: - struct Node { - uint64_t value; - Node *prev, *next; - Node(uint64_t val = -1) : value(val), prev(nullptr), next(nullptr) {} - }; - - LimitedSet(uint64_t maxRange) : head(new Node(-1)), tail(new Node(-1)) - { - nodes.resize(maxRange); - for (auto &node : nodes) { - node = new Node(-1); - } - head->next = tail; - tail->prev = head; - } - - ~LimitedSet() - { - for (auto &node : nodes) { - delete node; - } - delete head; - delete tail; - } - - void insert(uint64_t value) - { - if (nodes[value]->value == value) { - return; - } - Node *node = nodes[value]; - node->value = value; - Node *next = head->next; - node->next = next; - node->prev = head; - head->next = node; - next->prev = node; - } - - void remove(uint64_t value) - { - if (nodes[value]->value != value) { - return; - } - Node *node = nodes[value]; - node->prev->next = node->next; - node->next->prev = node->prev; - node->value = -1; - } - - bool find(uint64_t value) - { - return nodes[value]->value == value; - } - - class Iterator { - public: - Iterator(Node *node) : current(node) {} - bool operator != (const Iterator &other) const - { - return current != other.current; - } - const uint64_t &operator*() const - { - return current->value; - } - Iterator &operator ++ () - { - current = current->next; - return *this; - } - - private: - Node *current; - }; - - Iterator begin() - { - return { head->next }; - } - - Iterator end() - { - return { tail }; - } - -private: - Node *head; - Node *tail; - std::vector nodes; -}; enum class FkvState { FKV_EXIST, diff --git a/src/AccCTR/src/embedding_cache/limited_set.h b/src/AccCTR/src/embedding_cache/limited_set.h new file mode 100644 index 00000000..cbf4e8bd --- /dev/null +++ b/src/AccCTR/src/embedding_cache/limited_set.h @@ -0,0 +1,119 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef MXREC_LIMITED_SET_H +#define MXREC_LIMITED_SET_H + +#include +#include + +namespace EmbCache { + +class LimitedSet { +public: + struct Node { + uint64_t value; + Node *prev, *next; + Node(uint64_t val = -1) : value(val), prev(nullptr), next(nullptr) {} + }; + + LimitedSet(uint64_t maxRange) : head(new Node(-1)), tail(new Node(-1)) + { + nodes.resize(maxRange); + for (auto &node : nodes) { + node = new Node(-1); + } + head->next = tail; + tail->prev = head; + } + + ~LimitedSet() + { + for (auto &node : nodes) { + delete node; + } + delete head; + delete tail; + } + + void insert(uint64_t value) + { + if (nodes[value]->value == value) { + return; + } + Node *node = nodes[value]; + node->value = value; + Node *next = head->next; + node->next = next; + node->prev = head; + head->next = node; + next->prev = node; + } + + void remove(uint64_t value) + { + if (nodes[value]->value != value) { + return; + } + Node *node = nodes[value]; + node->prev->next = node->next; + node->next->prev = node->prev; + node->value = -1; + } + + bool find(uint64_t value) + { + return nodes[value]->value == value; + } + + class Iterator { + public: + Iterator(Node *node) : current(node) {} + bool operator != (const Iterator &other) const + { + return current != other.current; + } + const uint64_t &operator*() const + { + return current->value; + } + Iterator &operator ++ () + { + current = current->next; + return *this; + } + + private: + Node *current; + }; + + Iterator begin() + { + return { head->next }; + } + + Iterator end() + { + return { tail }; + } + +private: + Node *head; + Node *tail; + std::vector nodes; +}; + + +} +#endif // MXREC_LIMITED_SET_H diff --git a/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h b/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h index 6f4debb6..ee159341 100644 --- a/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h +++ b/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h @@ -100,7 +100,8 @@ struct alignas(K_ALIGNMENT)NetHashBucket { { /* don't put them into loop, flat code is faster than loop */ uint64_t oldKey = 0; - if (keys[BucketIdx::first].load(std::memory_order_relaxed) == 0 && keys[BucketIdx::first].compare_exchange_strong(oldKey, key)) { + if (keys[BucketIdx::first].load(std::memory_order_relaxed) == 0 && + keys[BucketIdx::first].compare_exchange_strong(oldKey, key)) { BeforePutFuncState ret = beforePutFunc(); if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_FAIL)) { keys[BucketIdx::first] = 0; @@ -119,7 +120,8 @@ struct alignas(K_ALIGNMENT)NetHashBucket { } oldKey = 0; - if (keys[BucketIdx::second].load(std::memory_order_relaxed) == 0 && keys[BucketIdx::second].compare_exchange_strong(oldKey, key)) { + if (keys[BucketIdx::second].load(std::memory_order_relaxed) == 0 && + keys[BucketIdx::second].compare_exchange_strong(oldKey, key)) { BeforePutFuncState ret = beforePutFunc(); if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_FAIL)) { keys[BucketIdx::second] = 0; @@ -138,7 +140,8 @@ struct alignas(K_ALIGNMENT)NetHashBucket { } oldKey = 0; - if (keys[BucketIdx::third].load(std::memory_order_relaxed) == 0 && keys[BucketIdx::third].compare_exchange_strong(oldKey, key)) { + if (keys[BucketIdx::third].load(std::memory_order_relaxed) == 0 && + keys[BucketIdx::third].compare_exchange_strong(oldKey, key)) { BeforePutFuncState ret = beforePutFunc(); if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_FAIL)) { keys[BucketIdx::third] = 0; @@ -189,7 +192,8 @@ struct alignas(K_ALIGNMENT)NetHashBucket { { /* don't put them into loop, flat code is faster than loop */ uint64_t oldValue = key; - if (keys[BucketIdx::first].load(std::memory_order_relaxed) == key && keys[BucketIdx::first].compare_exchange_strong(oldValue, 0)) { + if (keys[BucketIdx::first].load(std::memory_order_relaxed) == key && + keys[BucketIdx::first].compare_exchange_strong(oldValue, 0)) { values[BucketIdx::first] = 0; return FkvState::FKV_EXIST; } @@ -198,7 +202,8 @@ struct alignas(K_ALIGNMENT)NetHashBucket { } oldValue = key; - if (keys[BucketIdx::second].load(std::memory_order_relaxed) == key && keys[BucketIdx::second].compare_exchange_strong(oldValue, 0)) { + if (keys[BucketIdx::second].load(std::memory_order_relaxed) == key && + keys[BucketIdx::second].compare_exchange_strong(oldValue, 0)) { values[BucketIdx::second] = 0; return FkvState::FKV_EXIST; } @@ -207,7 +212,8 @@ struct alignas(K_ALIGNMENT)NetHashBucket { } oldValue = key; - if (keys[BucketIdx::third].load(std::memory_order_relaxed) == key && keys[BucketIdx::third].compare_exchange_strong(oldValue, 0)) { + if (keys[BucketIdx::third].load(std::memory_order_relaxed) == key && + keys[BucketIdx::third].compare_exchange_strong(oldValue, 0)) { values[BucketIdx::third] = 0; return FkvState::FKV_EXIST; } @@ -222,7 +228,8 @@ struct alignas(K_ALIGNMENT)NetHashBucket { { /* don't put them into loop, flat code is faster than loop */ uint64_t oldValue = key; - if (keys[BucketIdx::first].load(std::memory_order_relaxed) == key && keys[BucketIdx::first].compare_exchange_strong(oldValue, 0)) { + if (keys[BucketIdx::first].load(std::memory_order_relaxed) == key && + keys[BucketIdx::first].compare_exchange_strong(oldValue, 0)) { if (HM_UNLIKELY(beforeRemoveFunc(values[BucketIdx::first]) == BeforeRemoveFuncState::BEFORE_FAIL)) { return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; } @@ -235,7 +242,8 @@ struct alignas(K_ALIGNMENT)NetHashBucket { } oldValue = key; - if (keys[BucketIdx::second].load(std::memory_order_relaxed) == key && keys[BucketIdx::second].compare_exchange_strong(oldValue, 0)) { + if (keys[BucketIdx::second].load(std::memory_order_relaxed) == key && + keys[BucketIdx::second].compare_exchange_strong(oldValue, 0)) { if (HM_UNLIKELY(beforeRemoveFunc(values[BucketIdx::second]) == BeforeRemoveFuncState::BEFORE_FAIL)) { return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; } @@ -248,7 +256,8 @@ struct alignas(K_ALIGNMENT)NetHashBucket { } oldValue = key; - if (keys[BucketIdx::third].load(std::memory_order_relaxed) == key && keys[BucketIdx::third].compare_exchange_strong(oldValue, 0)) { + if (keys[BucketIdx::third].load(std::memory_order_relaxed) == key && + keys[BucketIdx::third].compare_exchange_strong(oldValue, 0)) { if (HM_UNLIKELY(beforeRemoveFunc(values[BucketIdx::third]) == BeforeRemoveFuncState::BEFORE_FAIL)) { return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; } @@ -391,75 +400,7 @@ public: } // did not find, now do put. continue from the last bucket in find - - /* try 8192 times */ - for (uint16_t i = 0; i < 8192; i++) { - /* loop all buckets linked */ - while (buck != nullptr) { - /* if there is an entry to put, just break */ - buck->spinLock.Lock(); - FkvState putRet = buck->Put(key, value, beforePutFunc); - buck->spinLock.UnLock(); - if (putRet == FkvState::FKV_NOT_EXIST) { - current_size++; - return FkvState::FKV_NOT_EXIST; - } - - if (HM_UNLIKELY(putRet == FkvState::FKV_KEY_CONFLICT)) { - return FkvState::FKV_KEY_CONFLICT; - } - - if (HM_UNLIKELY(putRet == FkvState::FKV_BEFORE_PUT_FUNC_FAIL)) { - return FkvState::FKV_BEFORE_PUT_FUNC_FAIL; - } - - if (HM_UNLIKELY(putRet == FkvState::FKV_NO_SPACE)) { - return FkvState::FKV_NO_SPACE; - } - - /* - * if no next bucket exist, just for break, - * else move to next bucket linked - */ - if (buck->next == nullptr) { - break; - } else { - buck = buck->next; - } - } - - /* - * if not put successfully in existing buckets, allocate a new one - * - * NOTES: just allocate memory, don't access new bucket in the spin lock scope, - * if access new bucket, which could trigger physical memory allocation which - * could trigger page fault, that is quite slow. In this case, spin lock - * could occupy too much CPU - */ - auto &lock = buck->spinLock; - lock.Lock(); - /* if other thread allocated new buck already, unlock and continue */ - if (buck->next != nullptr) { - buck = buck->next; - lock.UnLock(); - continue; - } - - /* firstly entered thread allocate new bucket */ - auto newBuck = static_cast(mOverflowEntryAlloc->Allocate(sizeof(NetHashBucket))); - if (HM_UNLIKELY(newBuck == nullptr)) { - lock.UnLock(); - ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "Failed to allocate new bucket"); - return FkvState::FKV_FAIL; - } - /* link to current buck, set buck to new buck */ - buck->next = newBuck; - buck = newBuck; - - /* unlock */ - lock.UnLock(); - } - return FkvState::FKV_FAIL; + return PutKeyValue(key, value, buck, beforePutFunc); } FkvState Remove(uint64_t key) @@ -496,18 +437,18 @@ public: FkvState Remove(uint64_t key, const std::function &beforeRemoveFunc) { if (HM_UNLIKELY(key == 0)) { - if (zeroInside) { - if (__sync_bool_compare_and_swap(&zeroInside, true, false)) { - auto ret = beforeRemoveFunc(zeroValue); - if (HM_UNLIKELY(ret == BeforeRemoveFuncState::BEFORE_FAIL)) { - return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; - } - zeroValue = 0; - current_size--; + if (!zeroInside) { + return FkvState::FKV_NOT_EXIST; + } + if (__sync_bool_compare_and_swap(&zeroInside, true, false)) { + auto ret = beforeRemoveFunc(zeroValue); + if (HM_UNLIKELY(ret == BeforeRemoveFuncState::BEFORE_FAIL)) { + return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; } - return FkvState::FKV_EXIST; + zeroValue = 0; + current_size--; } - return FkvState::FKV_NOT_EXIST; + return FkvState::FKV_EXIST; } /* get bucket */ @@ -643,20 +584,20 @@ public: const std::function &beforeRemoveFunc) { if (HM_UNLIKELY(key == 0)) { - if (zeroInside) { - value = zeroValue; - if (__sync_bool_compare_and_swap(&zeroInside, true, false)) { - auto ret = beforeRemoveFunc(zeroValue); - if (HM_UNLIKELY(ret == BeforeRemoveFuncState::BEFORE_FAIL)) { - return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; - } - zeroValue = 0; - current_size--; + if (!zeroInside) { + return FkvState::FKV_NOT_EXIST; + } + value = zeroValue; + if (__sync_bool_compare_and_swap(&zeroInside, true, false)) { + auto ret = beforeRemoveFunc(zeroValue); + if (HM_UNLIKELY(ret == BeforeRemoveFuncState::BEFORE_FAIL)) { + return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; } - - return FkvState::FKV_EXIST; + zeroValue = 0; + current_size--; } - return FkvState::FKV_NOT_EXIST; + + return FkvState::FKV_EXIST; } /* get bucket */ auto buck = &(mSubMaps[key % gSubMapCount][key % mBucketCount]); @@ -686,15 +627,7 @@ public: for (auto &mSubMap : mSubMaps) { for (uint32_t j = 0; j < mBucketCount; j++) { auto buck = &mSubMap[j]; - while (buck) { - for (size_t k = 0; k < K_KVNUMINBUCKET; k++) { - if (buck->keys[k] == 0) { - continue; - } - kvVec.emplace_back(buck->keys[k].load(), buck->values[k]); - } - buck = buck->next; - } + ExtractKeyValInBuck(buck, kvVec); } } return kvVec; @@ -786,6 +719,93 @@ private: } } } + + FkvState PutKeyValue(uint64_t key, uint64_t& value, EmbCache::NetHashBucket *buck, + const std::function& beforePutFunc) + { + /* try 8192 times */ + for (uint16_t i = 0; i < 8192; i++) { + /* loop all buckets linked */ + while (buck != nullptr) { + /* if there is an entry to put, just break */ + buck->spinLock.Lock(); + FkvState putRet = buck->Put(key, value, beforePutFunc); + buck->spinLock.UnLock(); + if (putRet == FkvState::FKV_NOT_EXIST) { + current_size++; + return FkvState::FKV_NOT_EXIST; + } + + if (HM_UNLIKELY(putRet == FkvState::FKV_KEY_CONFLICT)) { + return FkvState::FKV_KEY_CONFLICT; + } + + if (HM_UNLIKELY(putRet == FkvState::FKV_BEFORE_PUT_FUNC_FAIL)) { + return FkvState::FKV_BEFORE_PUT_FUNC_FAIL; + } + + if (HM_UNLIKELY(putRet == FkvState::FKV_NO_SPACE)) { + return FkvState::FKV_NO_SPACE; + } + + /* + * if no next bucket exist, just for break, + * else move to next bucket linked + */ + if (buck->next == nullptr) { + break; + } else { + buck = buck->next; + } + } + + /* + * if not put successfully in existing buckets, allocate a new one + * + * NOTES: just allocate memory, don't access new bucket in the spin lock scope, + * if access new bucket, which could trigger physical memory allocation which + * could trigger page fault, that is quite slow. In this case, spin lock + * could occupy too much CPU + */ + auto &lock = buck->spinLock; + lock.Lock(); + /* if other thread allocated new buck already, unlock and continue */ + if (buck->next != nullptr) { + buck = buck->next; + lock.UnLock(); + continue; + } + + /* firstly entered thread allocate new bucket */ + auto newBuck = static_cast(mOverflowEntryAlloc->Allocate(sizeof(NetHashBucket))); + if (HM_UNLIKELY(newBuck == nullptr)) { + lock.UnLock(); + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "Failed to allocate new bucket"); + return FkvState::FKV_FAIL; + } + /* link to current buck, set buck to new buck */ + buck->next = newBuck; + buck = newBuck; + + /* unlock */ + lock.UnLock(); + } + return FkvState::FKV_FAIL; + } + + void ExtractKeyValInBuck(EmbCache::NetHashBucket *buck, std::vector>& kvVec) + { + while (buck) { + for (size_t k = 0; k < K_KVNUMINBUCKET; k++) { + if (buck->keys[k] == 0) { + continue; + } + kvVec.emplace_back(buck->keys[k].load(), buck->values[k]); + } + buck = buck->next; + } + } + }; } #endif // MXREC_MAPPER_BASE_H diff --git a/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h b/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h index da69fcb9..552f5924 100644 --- a/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h +++ b/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h @@ -85,34 +85,62 @@ public: return maxCacheSize - useLength + evictSize; } - int GetSwapPairsAndKey2Offset(std::vector &keys, std::vector &swapInKeys, - std::vector &swapInPos, std::vector &swapOutKeys, std::vector &swapOutPos) + int GetSwapPairsAndKey2Offset(std::vector &keys, KeyOffsetPair &swapInKoPair, KeyOffsetPair &swapOutKoPair) { - std::vector swapInKeysID; + std::vector swapInKeysID = FilterKeys(keys, swapInKoPair); - for (uint64_t i = 0; i < keys.size(); i++) { - // Invalid key 不考虑 - if (HM_UNLIKELY(keys[i] == static_cast(INVALID_KEY))) { + uint64_t swapInCnt = 0; + int ret = FindInUsedPos(keys, swapInCnt, swapInKeysID, swapInKoPair, swapOutKoPair); + if (ret != ock::ctr::H_OK) { + return ret; + } + + // 剩下的Key从om中分配位置 + ret = FindInOffsetMapper(keys, swapInKoPair, swapInCnt, swapInKeysID); + if (ret != ock::ctr::H_OK) { + return ret; + } + + // 上个batch中的pos可被换出,加入validPos中 + for (uint64_t pos : lastBatchPos) { + if (HM_UNLIKELY(pos == static_cast(INVALID_KEY))) { continue; } - // 在HBM中的key, 原地替换为pos后从validPos中移除 - // 不在HBM中的key,加入swapInKeys,并记录在keys中的下标,用于后续key->offset - if (Find(keys[i], keys[i])) { - validPos->remove(keys[i]); - } else { - swapInKeys.push_back(keys[i]); - swapInKeysID.push_back(i); + validPos->insert(pos); + } + + // 这里keys都已被替换成offset,这个batch使用的pos在下个batch不能被换出,移出validPos + for (uint64_t pos : keys) { + if (HM_UNLIKELY(pos == static_cast(INVALID_KEY))) { + continue; } + validPos->remove(pos); + evictPos->remove(pos); } - swapInPos.resize(swapInKeys.size()); + lastBatchPos = keys; + return ock::ctr::H_OK; + } + + uint32_t GetUsage() + { + return useLength - evictSize; + } + + uint64_t FindInUsedPos(std::vector& keys, uint64_t& swapInCnt, std::vector& swapInKeysID, + KeyOffsetPair& swapInKoPair, KeyOffsetPair& swapOutKoPair) + { + std::vector &swapInKeys = swapInKoPair.first; + std::vector &swapInPos = swapInKoPair.second; + std::vector &swapOutKeys = swapOutKoPair.first; + std::vector &swapOutPos = swapOutKoPair.second; + // 换出量 = 换入量 - 剩余空间 uint64_t swapOutNum = swapInKeys.size() <= GetFreeLength() ? 0 : swapInKeys.size() - GetFreeLength(); swapOutKeys.resize(swapOutNum); swapOutPos.resize(swapOutNum); // 空间不足,前swapOutNum个Key从evictPos中拿可换出位置 - uint64_t swapInCnt = 0; for (uint64_t pos : *evictPos) { if (swapInCnt == swapInKeys.size()) { break; @@ -157,7 +185,15 @@ public: return ock::ctr::H_MAX_CACHESIZE_TOO_SMALL; } - // 剩下的Key从om中分配位置 + return ock::ctr::H_OK; + } + + int FindInOffsetMapper(std::vector& keys, KeyOffsetPair& swapInKoPair, uint64_t swapInCnt, + std::vector& swapInKeysID) + { + std::vector &swapInKeys = swapInKoPair.first; + std::vector &swapInPos = swapInKoPair.second; + for (uint64_t i = swapInCnt; i < swapInKeys.size(); i++) { swapInPos[i] = useLength++; if (HM_UNLIKELY(swapInPos[i] >= maxCacheSize)) { @@ -171,31 +207,31 @@ public: // key->offset keys[swapInKeysID[i]] = swapInPos[i]; } + return ock::ctr::H_OK; + } - // 上个batch中的pos可被换出,加入validPos中 - for (uint64_t pos : lastBatchPos) { - if (HM_UNLIKELY(pos == static_cast(INVALID_KEY))) { - continue; - } - validPos->insert(pos); - } + std::vector FilterKeys(std::vector& keys, KeyOffsetPair &swapInKoPair) + { + std::vector &swapInKeys = swapInKoPair.first; + std::vector &swapInPos = swapInKoPair.second; - // 这里keys都已被替换成offset,这个batch使用的pos在下个batch不能被换出,移出validPos - for (uint64_t pos : keys) { - if (HM_UNLIKELY(pos == static_cast(INVALID_KEY))) { + std::vector swapInKeysID; + for (uint64_t i = 0; i < keys.size(); i++) { + // Invalid key 不考虑 + if (HM_UNLIKELY(keys[i] == static_cast(INVALID_KEY))) { continue; } - validPos->remove(pos); - evictPos->remove(pos); + // 在HBM中的key, 原地替换为pos后从validPos中移除 + // 不在HBM中的key,加入swapInKeys,并记录在keys中的下标,用于后续key->offset + if (Find(keys[i], keys[i])) { + validPos->remove(keys[i]); + } else { + swapInKeys.push_back(keys[i]); + swapInKeysID.push_back(i); + } } - - lastBatchPos = keys; - return ock::ctr::H_OK; - } - - uint32_t GetUsage() - { - return useLength - evictSize; + swapInPos.resize(swapInKeys.size()); + return swapInKeysID; } private: diff --git a/src/AccCTR/src/unique/unique_func.cpp b/src/AccCTR/src/unique/unique_func.cpp index 462d6f9e..2059bd89 100644 --- a/src/AccCTR/src/unique/unique_func.cpp +++ b/src/AccCTR/src/unique/unique_func.cpp @@ -195,6 +195,10 @@ int ShardedDedup::HandleIdCountFill(std::vector> &idCount, } else { uint32_t memSize = idCount.size() * sizeof(int32_t); auto rc = memcpy_s(uniqueOut.idCnt, memSize, (int32_t *)(idCount.data()), memSize); + if (rc != 0) { + return rc; + } + int ret = PrintMemCpyLog(rc, memSize, "[TileAndFill/idCnt]"); if (ret != 0) { return ret; -- Gitee From 3a68d9834800db5a31aa5e2732cb50b363e620df Mon Sep 17 00:00:00 2001 From: yangzhen_BIG Date: Fri, 10 May 2024 19:22:31 +0800 Subject: [PATCH 4/7] cleancode --- src/AccCTR/src/embedding_cache/limited_set.h | 1 - src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h | 1 - src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h | 3 ++- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/AccCTR/src/embedding_cache/limited_set.h b/src/AccCTR/src/embedding_cache/limited_set.h index cbf4e8bd..036a6477 100644 --- a/src/AccCTR/src/embedding_cache/limited_set.h +++ b/src/AccCTR/src/embedding_cache/limited_set.h @@ -114,6 +114,5 @@ private: std::vector nodes; }; - } #endif // MXREC_LIMITED_SET_H diff --git a/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h b/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h index ee159341..f76a6252 100644 --- a/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h +++ b/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h @@ -805,7 +805,6 @@ private: buck = buck->next; } } - }; } #endif // MXREC_MAPPER_BASE_H diff --git a/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h b/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h index 552f5924..80170989 100644 --- a/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h +++ b/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h @@ -85,7 +85,8 @@ public: return maxCacheSize - useLength + evictSize; } - int GetSwapPairsAndKey2Offset(std::vector &keys, KeyOffsetPair &swapInKoPair, KeyOffsetPair &swapOutKoPair) + int GetSwapPairsAndKey2Offset(std::vector& keys, KeyOffsetPair& swapInKoPair, + KeyOffsetPair& swapOutKoPair) { std::vector swapInKeysID = FilterKeys(keys, swapInKoPair); -- Gitee From 6efe9793a0f1c25ad281b20f12d671ef11df325e Mon Sep 17 00:00:00 2001 From: yangzhen_BIG Date: Mon, 13 May 2024 10:55:30 +0800 Subject: [PATCH 5/7] cleancode and fix issue --- .../cache_manager/cache_manager.cpp | 79 ++++++++++--------- .../cache_manager/cache_manager.h | 50 ++++++------ .../embedding_local_table/emb_local_table.cpp | 18 +++-- .../truncated_normal_initializer.cpp | 8 ++ .../offset_mapper/mapper_base.h | 2 +- src/AccCTR/src/unique/unique_func.cpp | 3 + 6 files changed, 92 insertions(+), 68 deletions(-) diff --git a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp index c80be2b6..129ee51c 100644 --- a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp +++ b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp @@ -12,10 +12,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "cache_manager.h" + #include #include "external_logger.h" -#include "cache_manager.h" using namespace EmbCache; using namespace ock; @@ -23,9 +24,9 @@ using namespace ock::ctr; int64_t EmbCache::INVALID_KEY = -1; -int EmbCacheManagerImpl::CreateCacheForTable(const EmbCacheInfo &embCacheInfo, - const std::vector &initializerInfos, int64_t invalidKey, uint64_t prefillBufferSize, - uint32_t refillThreadNum) +int EmbCacheManagerImpl::CreateCacheForTable(const EmbCacheInfo& embCacheInfo, + const std::vector& initializerInfos, int64_t invalidKey, + uint64_t prefillBufferSize, uint32_t refillThreadNum) { int checkTableNameRet = CheckCreateTableName(embCacheInfo.tableName); if (checkTableNameRet != H_OK) { @@ -52,7 +53,7 @@ int EmbCacheManagerImpl::CreateCacheForTable(const EmbCacheInfo &embCacheInfo, if (embCacheInfo.extEmbeddingSize % embCacheInfo.embeddingSize != 0) { ExternalLogger::PrintLog(LogLevel::ERROR, "extEmbeddingSize = embeddingSize + optimizerSize, " - "which is divisible by embeddingSize"); + "which is divisible by embeddingSize"); return H_EXT_EMBEDDING_SIZE_INVALID; } @@ -71,26 +72,27 @@ int EmbCacheManagerImpl::CreateCacheForTable(const EmbCacheInfo &embCacheInfo, uint32_t reserve = embCacheInfo.vocabSize / VOCAB_CACHE_RATIO; if (!offsetMappers[embCacheInfo.tableName].Initialize(reserve, embCacheInfo.maxCacheSize)) { + offsetMappers[embCacheInfo.tableName].UnInitialize(); offsetMappers.erase(embCacheInfo.tableName); return H_MEMORY_ALLOC_ERROR; } - EmbPoolParam embPoolParam{ prefillBufferSize, refillThreadNum }; + EmbPoolParam embPoolParam{prefillBufferSize, refillThreadNum}; if (!embTables[embCacheInfo.tableName].Initialize(embCacheInfo.extEmbeddingSize, embCacheInfo.vocabSize, reserve, - initializerInfos, embPoolParam)) { + initializerInfos, embPoolParam)) { offsetMappers.erase(embCacheInfo.tableName); embTables.erase(embCacheInfo.tableName); return H_MEMORY_ALLOC_ERROR; } - embCacheInfos.insert({ embCacheInfo.tableName, embCacheInfo }); + embCacheInfos.insert({embCacheInfo.tableName, embCacheInfo}); INVALID_KEY = invalidKey; return H_OK; } -int EmbCacheManagerImpl::GetSwapPairsAndKey2Offset(std::string tableName, std::vector &keys, - KeyOffsetPair &swapInKoPair, KeyOffsetPair &swapOutKoPair) +int EmbCacheManagerImpl::GetSwapPairsAndKey2Offset(std::string tableName, std::vector& keys, + KeyOffsetPair& swapInKoPair, KeyOffsetPair& swapOutKoPair) { int checkRet = CheckGetSwapPairsAndKey2Offset(tableName, swapInKoPair, swapOutKoPair); if (checkRet != H_OK) { @@ -99,8 +101,8 @@ int EmbCacheManagerImpl::GetSwapPairsAndKey2Offset(std::string tableName, std::v return offsetMappers[tableName].GetSwapPairsAndKey2Offset(keys, swapInKoPair, swapOutKoPair); } -int EmbCacheManagerImpl::EmbeddingLookup(std::string tableName, const std::vector &keys, float *embAddr, - uint32_t threadNum) +int EmbCacheManagerImpl::EmbeddingLookup(std::string tableName, const std::vector& keys, float* embAddr, + uint32_t threadNum) { int checkTableNameRet = CheckValidTableName(tableName); if (checkTableNameRet != H_OK) { @@ -123,8 +125,8 @@ int EmbCacheManagerImpl::EmbeddingLookup(std::string tableName, const std::vecto return embTables[tableName].Gather(reinterpret_cast(embAddr), keys, threadNum); } -int EmbCacheManagerImpl::EmbeddingLookupAddrs(std::string tableName, const std::vector &keys, - std::vector &addrs, uint32_t threadNum) +int EmbCacheManagerImpl::EmbeddingLookupAddrs(std::string tableName, const std::vector& keys, + std::vector& addrs, uint32_t threadNum) { int checkTableNameRet = CheckValidTableName(tableName); if (checkTableNameRet != H_OK) { @@ -142,10 +144,9 @@ int EmbCacheManagerImpl::EmbeddingLookupAddrs(std::string tableName, const std:: return embTables[tableName].GatherAddrs(keys, addrs, threadNum); } - // 如果多线程使用,严格保证传入的key线程间不会重复(unique key),否则可能出现未定义结果 -int EmbCacheManagerImpl::EmbeddingLookupAndRemove(std::string tableName, const std::vector &keys, - float *embAddr, uint32_t threadNum) +int EmbCacheManagerImpl::EmbeddingLookupAndRemove(std::string tableName, const std::vector& keys, + float* embAddr, uint32_t threadNum) { int checkTableNameRet = CheckValidTableName(tableName); if (checkTableNameRet != H_OK) { @@ -168,15 +169,15 @@ int EmbCacheManagerImpl::EmbeddingLookupAndRemove(std::string tableName, const s return embTables[tableName].GatherAndRemove(reinterpret_cast(embAddr), keys, threadNum); } -int EmbCacheManagerImpl::EmbeddingUpdate(std::string tableName, const std::vector &keys, float *embAddr, - uint32_t threadNum) +int EmbCacheManagerImpl::EmbeddingUpdate(std::string tableName, const std::vector& keys, float* embAddr, + uint32_t threadNum) { int checkTableNameRet = CheckValidTableName(tableName); if (checkTableNameRet != H_OK) { return checkTableNameRet; } - if (!CheckValidThreadNum(threadNum)) { // 检查thread是否小于核数 + if (!CheckValidThreadNum(threadNum)) { // 检查thread是否小于核数 return H_THREAD_NUM_ERROR; } @@ -184,7 +185,7 @@ int EmbCacheManagerImpl::EmbeddingUpdate(std::string tableName, const std::vecto return H_OK; } - if (embAddr == nullptr) { // 检查embAddr是不是空指针 + if (embAddr == nullptr) { // 检查embAddr是不是空指针 ExternalLogger::PrintLog(LogLevel::ERROR, "embAddr is nullptr"); return H_ADDRESS_NULL; } @@ -192,13 +193,17 @@ int EmbCacheManagerImpl::EmbeddingUpdate(std::string tableName, const std::vecto return embTables[tableName].Scatter(reinterpret_cast(embAddr), keys, threadNum); } -int EmbCacheManagerImpl::EmbeddingRemove(std::string tableName, const std::vector &keys, uint32_t threadNum) +int EmbCacheManagerImpl::EmbeddingRemove(std::string tableName, const std::vector& keys, uint32_t threadNum) { int checkTableNameRet = CheckValidTableName(tableName); if (checkTableNameRet != H_OK) { return checkTableNameRet; } + if (!CheckValidThreadNum(threadNum)) { // 检查thread是否小于核数 + return H_THREAD_NUM_ERROR; + } + if (keys.empty()) { return H_OK; } @@ -206,7 +211,7 @@ int EmbCacheManagerImpl::EmbeddingRemove(std::string tableName, const std::vecto return embTables[tableName].RemoveByKeys(keys, threadNum); } -int EmbCacheManagerImpl::RemoveEmbsByKeys(std::string tableName, const std::vector &keys) +int EmbCacheManagerImpl::RemoveEmbsByKeys(std::string tableName, const std::vector& keys) { int checkTableNameRet = CheckValidTableName(tableName); if (checkTableNameRet != H_OK) { @@ -225,32 +230,32 @@ int EmbCacheManagerImpl::RemoveEmbsByKeys(std::string tableName, const std::vect return H_OK; } -int EmbCacheManagerImpl::GetEmbTableNames(std::vector &allTableNames) +int EmbCacheManagerImpl::GetEmbTableNames(std::vector& allTableNames) { if (!allTableNames.empty()) { ExternalLogger::PrintLog(LogLevel::ERROR, "allTableNames should be empty"); return H_ARG_NOT_EMPTY; } allTableNames.reserve(embTables.size()); - for (auto &embTable : embTables) { + for (auto& embTable : embTables) { allTableNames.emplace_back(embTable.first); } return H_OK; } int EmbCacheManagerImpl::ExportDeviceKeyOffsetPairs(std::string tableName, - std::vector> &koVec) + std::vector>& koVec) { int checkTableNameRet = CheckValidTableName(tableName); if (checkTableNameRet != H_OK) { return checkTableNameRet; } - OffsetMapper &om = offsetMappers[tableName]; + OffsetMapper& om = offsetMappers[tableName]; koVec = om.ExportSortedKVPairs(); return H_OK; } -int EmbCacheManagerImpl::Serialize(std::string tableName, std::vector &buffer) +int EmbCacheManagerImpl::Serialize(std::string tableName, std::vector& buffer) { int checkTableNameRet = CheckValidTableName(tableName); if (checkTableNameRet != H_OK) { @@ -260,7 +265,7 @@ int EmbCacheManagerImpl::Serialize(std::string tableName, std::vector &buf return H_OK; } -int EmbCacheManagerImpl::Deserialize(std::string tableName, const std::vector &buffer) +int EmbCacheManagerImpl::Deserialize(std::string tableName, const std::vector& buffer) { int checkTableNameRet = CheckValidTableName(tableName); if (checkTableNameRet != H_OK) { @@ -289,7 +294,7 @@ int EmbCacheManagerImpl::CheckValidTableName(std::string tableName) { if (tableName.size() > TABLE_NAME_MAX_SIZE) { ExternalLogger::PrintLog(LogLevel::ERROR, - "tableName size can not larger than " + std::to_string(TABLE_NAME_MAX_SIZE)); + "tableName size can not larger than " + std::to_string(TABLE_NAME_MAX_SIZE)); return H_TABLE_NAME_TOO_LONG; } auto om = offsetMappers.find(tableName); @@ -304,9 +309,9 @@ int EmbCacheManagerImpl::CheckValidTableName(std::string tableName) bool EmbCacheManagerImpl::CheckInitializer(uint32_t extEmbSize, std::vector initializerInfos) { std::sort(initializerInfos.begin(), initializerInfos.end(), - [](const auto &u, const auto &v) { return u.start < v.start; }); + [](const auto& u, const auto& v) { return u.start < v.start; }); uint32_t cur_pos = 0; - for (const auto &info : initializerInfos) { + for (const auto& info : initializerInfos) { if (info.initializer == nullptr) { ExternalLogger::PrintLog(LogLevel::ERROR, "initializer is nullptr"); return false; @@ -340,8 +345,8 @@ bool EmbCacheManagerImpl::CheckValidThreadNum(uint32_t threadNum) return true; } -int EmbCacheManagerImpl::CheckGetSwapPairsAndKey2Offset(std::string tableName, const KeyOffsetPair &swapInKoPair, - const KeyOffsetPair &swapOutKoPair) +int EmbCacheManagerImpl::CheckGetSwapPairsAndKey2Offset(std::string tableName, const KeyOffsetPair& swapInKoPair, + const KeyOffsetPair& swapOutKoPair) { if (!swapInKoPair.first.empty() || !swapInKoPair.second.empty() || !swapOutKoPair.first.empty() || !swapOutKoPair.second.empty()) { @@ -357,7 +362,7 @@ int EmbCacheManagerImpl::CheckGetSwapPairsAndKey2Offset(std::string tableName, c return H_OK; } -int EmbCacheManagerImpl::CheckCreateTableName(const std::string &tableName) +int EmbCacheManagerImpl::CheckCreateTableName(const std::string& tableName) { if (tableName.empty()) { ExternalLogger::PrintLog(LogLevel::ERROR, "tableName can not be empty"); @@ -366,13 +371,13 @@ int EmbCacheManagerImpl::CheckCreateTableName(const std::string &tableName) if (tableName.size() > TABLE_NAME_MAX_SIZE) { ExternalLogger::PrintLog(LogLevel::ERROR, - "tableName size can not larger than " + std::to_string(TABLE_NAME_MAX_SIZE)); + "tableName size can not larger than " + std::to_string(TABLE_NAME_MAX_SIZE)); return H_TABLE_NAME_TOO_LONG; } return H_OK; } -uint32_t EmbCacheManagerImpl::GetUsage(const std::string &tableName) +uint32_t EmbCacheManagerImpl::GetUsage(const std::string& tableName) { return offsetMappers[tableName].GetUsage(); } diff --git a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.h b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.h index 314d0572..d8c4ed9b 100644 --- a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.h +++ b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.h @@ -21,9 +21,9 @@ limitations under the License. #include #include "embedding_cache.h" +#include "embedding_local_table/emb_local_table.h" #include "error_code.h" #include "offset_mapper/offset_mapper.h" -#include "embedding_local_table/emb_local_table.h" namespace EmbCache { class EmbCacheManagerImpl : public EmbCacheManager { @@ -32,39 +32,39 @@ public: ~EmbCacheManagerImpl() override = default; - int CreateCacheForTable(const EmbCacheInfo &embCacheInfo, const std::vector &initializerInfos, - int64_t invalidKey, uint64_t prefillBufferSize, uint32_t refillThreadNum) override; + int CreateCacheForTable(const EmbCacheInfo& embCacheInfo, const std::vector& initializerInfos, + int64_t invalidKey, uint64_t prefillBufferSize, uint32_t refillThreadNum) override; - int GetSwapPairsAndKey2Offset(std::string tableName, std::vector &keys, KeyOffsetPair &swapInKoPair, - KeyOffsetPair &swapOutKoPair) override; + int GetSwapPairsAndKey2Offset(std::string tableName, std::vector& keys, KeyOffsetPair& swapInKoPair, + KeyOffsetPair& swapOutKoPair) override; - int EmbeddingLookup(std::string tableName, const std::vector &keys, float *embAddr, - uint32_t threadNum) override; + int EmbeddingLookup(std::string tableName, const std::vector& keys, float* embAddr, + uint32_t threadNum) override; - int EmbeddingLookupAddrs(std::string tableName, const std::vector &keys, std::vector &addrs, - uint32_t threadNum) override; + int EmbeddingLookupAddrs(std::string tableName, const std::vector& keys, std::vector& addrs, + uint32_t threadNum) override; - int EmbeddingUpdate(std::string tableName, const std::vector &keys, float *embAddr, - uint32_t threadNum) override; + int EmbeddingUpdate(std::string tableName, const std::vector& keys, float* embAddr, + uint32_t threadNum) override; - int EmbeddingRemove(std::string tableName, const std::vector &keys, uint32_t threadNum) override; + int EmbeddingRemove(std::string tableName, const std::vector& keys, uint32_t threadNum) override; - int EmbeddingLookupAndRemove(std::string tableName, const std::vector &keys, float *embAddr, - uint32_t threadNum) override; + int EmbeddingLookupAndRemove(std::string tableName, const std::vector& keys, float* embAddr, + uint32_t threadNum) override; - int RemoveEmbsByKeys(std::string tableName, const std::vector &keys) override; + int RemoveEmbsByKeys(std::string tableName, const std::vector& keys) override; - int GetEmbTableNames(std::vector &allTableNames) override; + int GetEmbTableNames(std::vector& allTableNames) override; - int ExportDeviceKeyOffsetPairs(std::string tableName, std::vector> &koVec) override; + int ExportDeviceKeyOffsetPairs(std::string tableName, std::vector>& koVec) override; - int Serialize(std::string tableName, std::vector &buffer) override; + int Serialize(std::string tableName, std::vector& buffer) override; - int Deserialize(std::string tableName, const std::vector &buffer) override; + int Deserialize(std::string tableName, const std::vector& buffer) override; void Destroy() override; - uint32_t GetUsage(const std::string &tableName) override; + uint32_t GetUsage(const std::string& tableName) override; private: std::map embCacheInfos; @@ -77,10 +77,10 @@ private: bool CheckValidThreadNum(uint32_t threadNum); - int CheckGetSwapPairsAndKey2Offset(std::string tableName, const KeyOffsetPair &swapInKoPair, - const KeyOffsetPair &swapOutKoPair); + int CheckGetSwapPairsAndKey2Offset(std::string tableName, const KeyOffsetPair& swapInKoPair, + const KeyOffsetPair& swapOutKoPair); - int CheckCreateTableName(const std::string &tableName); + int CheckCreateTableName(const std::string& tableName); }; -} -#endif // EMBEDDING_CACHE_MANAGER_H +} // namespace EmbCache +#endif // EMBEDDING_CACHE_MANAGER_H diff --git a/src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.cpp b/src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.cpp index f0b050d6..84874142 100644 --- a/src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.cpp +++ b/src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.cpp @@ -60,7 +60,9 @@ int EmbLocalTable::RemoveByKeys(const std::vector &keys, uint32_t thre { if (threadNum == 1) { for (uint64_t key : keys) { - Remove(key); + if(!Remove(key)){ + return H_ERROR; + } } return H_OK; } @@ -76,16 +78,22 @@ int EmbLocalTable::RemoveByKeys(const std::vector &keys, uint32_t thre start[threadId] = (keys.size() / threadNum) * threadId + m; } - vector threads(threadNum); + vector> threads(threadNum); for (uint32_t threadId = 0; threadId < threadNum; threadId++) { - threads[threadId] = thread([&, threadId] { + threads[threadId] = std::async(std::launch::async, [&, threadId]() { for (uint64_t i = start[threadId]; i < start[threadId + 1]; i++) { - Remove(keys[i]); + if (!Remove(keys[i])) { + return H_ERROR; + } } + return H_OK; }); } for (auto &t : threads) { - t.join(); + auto res = t.get(); + if (res != H_OK) { + return res; + } } return H_OK; } diff --git a/src/AccCTR/src/embedding_cache/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp b/src/AccCTR/src/embedding_cache/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp index c2441466..e605fc75 100644 --- a/src/AccCTR/src/embedding_cache/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp +++ b/src/AccCTR/src/embedding_cache/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp @@ -35,6 +35,7 @@ TruncatedNormalInitializer::TruncatedNormalInitializer(uint32_t start, uint32_t } else { mean = initInfo.mean; } + if (initInfo.stddev > NORMAL_STDDEV_MAX) { ExternalLogger::PrintLog(LogLevel::WARN, "truncated normal stddev param is greater than " + std::to_string(NORMAL_STDDEV_MAX) + ", and will use " + std::to_string(NORMAL_STDDEV_MAX) + "."); @@ -46,6 +47,13 @@ TruncatedNormalInitializer::TruncatedNormalInitializer(uint32_t start, uint32_t } else { stddev = initInfo.stddev; } + + if (abs(stddev) < std::numeric_limits::epsilon()) { + ExternalLogger::PrintLog( + LogLevel::WARN, + "truncated normal stddev param is zero, initialization can be slow, suggest using constant initializer"); + } + if (initInfo.initK > INIT_K_MAX) { ExternalLogger::PrintLog(LogLevel::WARN, "truncated normal initK is greater than " + std::to_string(INIT_K_MAX) + ", and will use " + std::to_string(INIT_K_MAX) + "."); diff --git a/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h b/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h index f76a6252..b10c3b8d 100644 --- a/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h +++ b/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h @@ -315,7 +315,7 @@ public: /* make physical page and set to zero */ auto ret = memset_s(tmp, sizeof(NetHashBucket) * bucketCount, 0, sizeof(NetHashBucket) * bucketCount); if (ret != 0) { - ock::ExternalLogger::ExternalLogger::PrintLog(ock::LogLevel::ERROR, + ock::ExternalLogger::PrintLog(ock::LogLevel::ERROR, "memset_s failed... size: " + std::to_string(sizeof(NetHashBucket) * bucketCount)); return false; } diff --git a/src/AccCTR/src/unique/unique_func.cpp b/src/AccCTR/src/unique/unique_func.cpp index 2059bd89..45ac768a 100644 --- a/src/AccCTR/src/unique/unique_func.cpp +++ b/src/AccCTR/src/unique/unique_func.cpp @@ -188,6 +188,9 @@ int ShardedDedup::HandleIdCountFill(std::vector> &idCount, if (conf.usePadding) { uint32_t memSize = idCount.size() * sizeof(int32_t); auto rc = memcpy_s(uniqueOut.idCntFill, memSize, (int32_t *)(idCount.data()), memSize); + if (rc != 0) { + return rc; + } int ret = PrintMemCpyLog(rc, memSize, "[TileAndFill/idCntFill]"); if (ret != 0) { return ret; -- Gitee From 31fec9bd41a628954052005d1548608fe80f54a7 Mon Sep 17 00:00:00 2001 From: yangzhen_BIG Date: Mon, 13 May 2024 11:10:19 +0800 Subject: [PATCH 6/7] cleancode --- .../embedding_local_table/emb_local_table.cpp | 2 +- .../offset_mapper/mapper_base.h | 90 +++++++++---------- 2 files changed, 46 insertions(+), 46 deletions(-) diff --git a/src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.cpp b/src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.cpp index 84874142..af98c03c 100644 --- a/src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.cpp +++ b/src/AccCTR/src/embedding_cache/embedding_local_table/emb_local_table.cpp @@ -60,7 +60,7 @@ int EmbLocalTable::RemoveByKeys(const std::vector &keys, uint32_t thre { if (threadNum == 1) { for (uint64_t key : keys) { - if(!Remove(key)){ + if (!Remove(key)) { return H_ERROR; } } diff --git a/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h b/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h index b10c3b8d..5eb48789 100644 --- a/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h +++ b/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h @@ -38,9 +38,9 @@ static constexpr size_t K_ALIGNMENT = 64; static constexpr size_t K_KVNUMINBUCKET = 3; enum BucketIdx { - first, - second, - third + FIRST, + SECOND, + THIRD }; class NetHeapAllocator { @@ -100,18 +100,18 @@ struct alignas(K_ALIGNMENT)NetHashBucket { { /* don't put them into loop, flat code is faster than loop */ uint64_t oldKey = 0; - if (keys[BucketIdx::first].load(std::memory_order_relaxed) == 0 && - keys[BucketIdx::first].compare_exchange_strong(oldKey, key)) { + if (keys[BucketIdx::FIRST].load(std::memory_order_relaxed) == 0 && + keys[BucketIdx::FIRST].compare_exchange_strong(oldKey, key)) { BeforePutFuncState ret = beforePutFunc(); if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_FAIL)) { - keys[BucketIdx::first] = 0; + keys[BucketIdx::FIRST] = 0; return FkvState::FKV_BEFORE_PUT_FUNC_FAIL; } if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_NO_SPACE)) { - keys[BucketIdx::first] = 0; + keys[BucketIdx::FIRST] = 0; return FkvState::FKV_NO_SPACE; } - values[BucketIdx::first] = value; + values[BucketIdx::FIRST] = value; return FkvState::FKV_NOT_EXIST; } @@ -120,18 +120,18 @@ struct alignas(K_ALIGNMENT)NetHashBucket { } oldKey = 0; - if (keys[BucketIdx::second].load(std::memory_order_relaxed) == 0 && - keys[BucketIdx::second].compare_exchange_strong(oldKey, key)) { + if (keys[BucketIdx::SECOND].load(std::memory_order_relaxed) == 0 && + keys[BucketIdx::SECOND].compare_exchange_strong(oldKey, key)) { BeforePutFuncState ret = beforePutFunc(); if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_FAIL)) { - keys[BucketIdx::second] = 0; + keys[BucketIdx::SECOND] = 0; return FkvState::FKV_BEFORE_PUT_FUNC_FAIL; } if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_NO_SPACE)) { - keys[BucketIdx::second] = 0; + keys[BucketIdx::SECOND] = 0; return FkvState::FKV_NO_SPACE; } - values[BucketIdx::second] = value; + values[BucketIdx::SECOND] = value; return FkvState::FKV_NOT_EXIST; } @@ -140,18 +140,18 @@ struct alignas(K_ALIGNMENT)NetHashBucket { } oldKey = 0; - if (keys[BucketIdx::third].load(std::memory_order_relaxed) == 0 && - keys[BucketIdx::third].compare_exchange_strong(oldKey, key)) { + if (keys[BucketIdx::THIRD].load(std::memory_order_relaxed) == 0 && + keys[BucketIdx::THIRD].compare_exchange_strong(oldKey, key)) { BeforePutFuncState ret = beforePutFunc(); if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_FAIL)) { - keys[BucketIdx::third] = 0; + keys[BucketIdx::THIRD] = 0; return FkvState::FKV_BEFORE_PUT_FUNC_FAIL; } if (HM_UNLIKELY(ret == BeforePutFuncState::BEFORE_NO_SPACE)) { - keys[BucketIdx::third] = 0; + keys[BucketIdx::THIRD] = 0; return FkvState::FKV_NO_SPACE; } - values[BucketIdx::third] = value; + values[BucketIdx::THIRD] = value; return FkvState::FKV_NOT_EXIST; } @@ -170,18 +170,18 @@ struct alignas(K_ALIGNMENT)NetHashBucket { /* * expand the loop, instead of put them into a for/while loop for performance */ - if (key == keys[BucketIdx::first].load(std::memory_order_relaxed)) { - value = values[BucketIdx::first]; + if (key == keys[BucketIdx::FIRST].load(std::memory_order_relaxed)) { + value = values[BucketIdx::FIRST]; return true; } - if (key == keys[BucketIdx::second].load(std::memory_order_relaxed)) { - value = values[BucketIdx::second]; + if (key == keys[BucketIdx::SECOND].load(std::memory_order_relaxed)) { + value = values[BucketIdx::SECOND]; return true; } - if (key == keys[BucketIdx::third].load(std::memory_order_relaxed)) { - value = values[BucketIdx::third]; + if (key == keys[BucketIdx::THIRD].load(std::memory_order_relaxed)) { + value = values[BucketIdx::THIRD]; return true; } @@ -192,9 +192,9 @@ struct alignas(K_ALIGNMENT)NetHashBucket { { /* don't put them into loop, flat code is faster than loop */ uint64_t oldValue = key; - if (keys[BucketIdx::first].load(std::memory_order_relaxed) == key && - keys[BucketIdx::first].compare_exchange_strong(oldValue, 0)) { - values[BucketIdx::first] = 0; + if (keys[BucketIdx::FIRST].load(std::memory_order_relaxed) == key && + keys[BucketIdx::FIRST].compare_exchange_strong(oldValue, 0)) { + values[BucketIdx::FIRST] = 0; return FkvState::FKV_EXIST; } if (HM_UNLIKELY(oldValue == 0)) { @@ -202,9 +202,9 @@ struct alignas(K_ALIGNMENT)NetHashBucket { } oldValue = key; - if (keys[BucketIdx::second].load(std::memory_order_relaxed) == key && - keys[BucketIdx::second].compare_exchange_strong(oldValue, 0)) { - values[BucketIdx::second] = 0; + if (keys[BucketIdx::SECOND].load(std::memory_order_relaxed) == key && + keys[BucketIdx::SECOND].compare_exchange_strong(oldValue, 0)) { + values[BucketIdx::SECOND] = 0; return FkvState::FKV_EXIST; } if (HM_UNLIKELY(oldValue == 0)) { @@ -212,9 +212,9 @@ struct alignas(K_ALIGNMENT)NetHashBucket { } oldValue = key; - if (keys[BucketIdx::third].load(std::memory_order_relaxed) == key && - keys[BucketIdx::third].compare_exchange_strong(oldValue, 0)) { - values[BucketIdx::third] = 0; + if (keys[BucketIdx::THIRD].load(std::memory_order_relaxed) == key && + keys[BucketIdx::THIRD].compare_exchange_strong(oldValue, 0)) { + values[BucketIdx::THIRD] = 0; return FkvState::FKV_EXIST; } if (HM_UNLIKELY(oldValue == 0)) { @@ -228,13 +228,13 @@ struct alignas(K_ALIGNMENT)NetHashBucket { { /* don't put them into loop, flat code is faster than loop */ uint64_t oldValue = key; - if (keys[BucketIdx::first].load(std::memory_order_relaxed) == key && - keys[BucketIdx::first].compare_exchange_strong(oldValue, 0)) { - if (HM_UNLIKELY(beforeRemoveFunc(values[BucketIdx::first]) == BeforeRemoveFuncState::BEFORE_FAIL)) { + if (keys[BucketIdx::FIRST].load(std::memory_order_relaxed) == key && + keys[BucketIdx::FIRST].compare_exchange_strong(oldValue, 0)) { + if (HM_UNLIKELY(beforeRemoveFunc(values[BucketIdx::FIRST]) == BeforeRemoveFuncState::BEFORE_FAIL)) { return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; } - values[BucketIdx::first] = 0; + values[BucketIdx::FIRST] = 0; return FkvState::FKV_EXIST; } if (HM_UNLIKELY(oldValue == 0)) { @@ -242,13 +242,13 @@ struct alignas(K_ALIGNMENT)NetHashBucket { } oldValue = key; - if (keys[BucketIdx::second].load(std::memory_order_relaxed) == key && - keys[BucketIdx::second].compare_exchange_strong(oldValue, 0)) { - if (HM_UNLIKELY(beforeRemoveFunc(values[BucketIdx::second]) == BeforeRemoveFuncState::BEFORE_FAIL)) { + if (keys[BucketIdx::SECOND].load(std::memory_order_relaxed) == key && + keys[BucketIdx::SECOND].compare_exchange_strong(oldValue, 0)) { + if (HM_UNLIKELY(beforeRemoveFunc(values[BucketIdx::SECOND]) == BeforeRemoveFuncState::BEFORE_FAIL)) { return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; } - values[BucketIdx::second] = 0; + values[BucketIdx::SECOND] = 0; return FkvState::FKV_EXIST; } if (HM_UNLIKELY(oldValue == 0)) { @@ -256,13 +256,13 @@ struct alignas(K_ALIGNMENT)NetHashBucket { } oldValue = key; - if (keys[BucketIdx::third].load(std::memory_order_relaxed) == key && - keys[BucketIdx::third].compare_exchange_strong(oldValue, 0)) { - if (HM_UNLIKELY(beforeRemoveFunc(values[BucketIdx::third]) == BeforeRemoveFuncState::BEFORE_FAIL)) { + if (keys[BucketIdx::THIRD].load(std::memory_order_relaxed) == key && + keys[BucketIdx::THIRD].compare_exchange_strong(oldValue, 0)) { + if (HM_UNLIKELY(beforeRemoveFunc(values[BucketIdx::THIRD]) == BeforeRemoveFuncState::BEFORE_FAIL)) { return FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL; } - values[BucketIdx::third] = 0; + values[BucketIdx::THIRD] = 0; return FkvState::FKV_EXIST; } if (HM_UNLIKELY(oldValue == 0)) { -- Gitee From 2849cb04ef3b5a7eb9774811b9a6a5331ce13ad2 Mon Sep 17 00:00:00 2001 From: yangzhen_BIG Date: Mon, 13 May 2024 21:02:29 +0800 Subject: [PATCH 7/7] cleancode and fix issue --- .../src/common/util/external_threader.h | 4 +-- .../cache_manager/cache_manager.cpp | 33 ++++++++++--------- .../cache_manager/cache_manager.h | 29 ++++++++-------- .../offset_mapper/address_mapper.h | 12 ++++--- .../offset_mapper/mapper_base.h | 2 +- src/AccCTR/src/include/embedding_cache.h | 23 ++++++------- 6 files changed, 54 insertions(+), 49 deletions(-) diff --git a/src/AccCTR/src/common/util/external_threader.h b/src/AccCTR/src/common/util/external_threader.h index e6b723d7..5f7c500f 100644 --- a/src/AccCTR/src/common/util/external_threader.h +++ b/src/AccCTR/src/common/util/external_threader.h @@ -60,7 +60,7 @@ public: { std::lock_guard lock(taskMutex); - auto pt = std::make_unique>(f); + auto pt = std::make_unique>(std::forward(f)); auto fut = pt->get_future(); tasks.emplace(std::move(pt)); taskCv.notify_one(); @@ -72,7 +72,7 @@ private: std::queue>> tasks; std::mutex taskMutex; std::condition_variable taskCv; - volatile bool stop = false; + std::atomic stop = false; void WorkerThread() { diff --git a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp index 129ee51c..aedb52b3 100644 --- a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp +++ b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp @@ -91,7 +91,7 @@ int EmbCacheManagerImpl::CreateCacheForTable(const EmbCacheInfo& embCacheInfo, return H_OK; } -int EmbCacheManagerImpl::GetSwapPairsAndKey2Offset(std::string tableName, std::vector& keys, +int EmbCacheManagerImpl::GetSwapPairsAndKey2Offset(const std::string& tableName, std::vector& keys, KeyOffsetPair& swapInKoPair, KeyOffsetPair& swapOutKoPair) { int checkRet = CheckGetSwapPairsAndKey2Offset(tableName, swapInKoPair, swapOutKoPair); @@ -101,8 +101,8 @@ int EmbCacheManagerImpl::GetSwapPairsAndKey2Offset(std::string tableName, std::v return offsetMappers[tableName].GetSwapPairsAndKey2Offset(keys, swapInKoPair, swapOutKoPair); } -int EmbCacheManagerImpl::EmbeddingLookup(std::string tableName, const std::vector& keys, float* embAddr, - uint32_t threadNum) +int EmbCacheManagerImpl::EmbeddingLookup(const std::string& tableName, const std::vector& keys, + float* embAddr, uint32_t threadNum) { int checkTableNameRet = CheckValidTableName(tableName); if (checkTableNameRet != H_OK) { @@ -125,7 +125,7 @@ int EmbCacheManagerImpl::EmbeddingLookup(std::string tableName, const std::vecto return embTables[tableName].Gather(reinterpret_cast(embAddr), keys, threadNum); } -int EmbCacheManagerImpl::EmbeddingLookupAddrs(std::string tableName, const std::vector& keys, +int EmbCacheManagerImpl::EmbeddingLookupAddrs(const std::string& tableName, const std::vector& keys, std::vector& addrs, uint32_t threadNum) { int checkTableNameRet = CheckValidTableName(tableName); @@ -145,7 +145,7 @@ int EmbCacheManagerImpl::EmbeddingLookupAddrs(std::string tableName, const std:: } // 如果多线程使用,严格保证传入的key线程间不会重复(unique key),否则可能出现未定义结果 -int EmbCacheManagerImpl::EmbeddingLookupAndRemove(std::string tableName, const std::vector& keys, +int EmbCacheManagerImpl::EmbeddingLookupAndRemove(const std::string& tableName, const std::vector& keys, float* embAddr, uint32_t threadNum) { int checkTableNameRet = CheckValidTableName(tableName); @@ -169,8 +169,8 @@ int EmbCacheManagerImpl::EmbeddingLookupAndRemove(std::string tableName, const s return embTables[tableName].GatherAndRemove(reinterpret_cast(embAddr), keys, threadNum); } -int EmbCacheManagerImpl::EmbeddingUpdate(std::string tableName, const std::vector& keys, float* embAddr, - uint32_t threadNum) +int EmbCacheManagerImpl::EmbeddingUpdate(const std::string& tableName, const std::vector& keys, + float* embAddr, uint32_t threadNum) { int checkTableNameRet = CheckValidTableName(tableName); if (checkTableNameRet != H_OK) { @@ -193,7 +193,8 @@ int EmbCacheManagerImpl::EmbeddingUpdate(std::string tableName, const std::vecto return embTables[tableName].Scatter(reinterpret_cast(embAddr), keys, threadNum); } -int EmbCacheManagerImpl::EmbeddingRemove(std::string tableName, const std::vector& keys, uint32_t threadNum) +int EmbCacheManagerImpl::EmbeddingRemove(const std::string& tableName, const std::vector& keys, + uint32_t threadNum) { int checkTableNameRet = CheckValidTableName(tableName); if (checkTableNameRet != H_OK) { @@ -211,14 +212,14 @@ int EmbCacheManagerImpl::EmbeddingRemove(std::string tableName, const std::vecto return embTables[tableName].RemoveByKeys(keys, threadNum); } -int EmbCacheManagerImpl::RemoveEmbsByKeys(std::string tableName, const std::vector& keys) +int EmbCacheManagerImpl::RemoveEmbsByKeys(const std::string& tableName, const std::vector& keys) { int checkTableNameRet = CheckValidTableName(tableName); if (checkTableNameRet != H_OK) { return checkTableNameRet; } - auto om = offsetMappers.find(tableName); - auto embTable = embTables.find(tableName); + const auto& om = offsetMappers.find(tableName); + const auto& embTable = embTables.find(tableName); for (auto key : keys) { if (key == static_cast(INVALID_KEY)) { ExternalLogger::PrintLog(LogLevel::WARN, "Try to evict invalid key"); @@ -243,7 +244,7 @@ int EmbCacheManagerImpl::GetEmbTableNames(std::vector& allTableName return H_OK; } -int EmbCacheManagerImpl::ExportDeviceKeyOffsetPairs(std::string tableName, +int EmbCacheManagerImpl::ExportDeviceKeyOffsetPairs(const std::string& tableName, std::vector>& koVec) { int checkTableNameRet = CheckValidTableName(tableName); @@ -255,7 +256,7 @@ int EmbCacheManagerImpl::ExportDeviceKeyOffsetPairs(std::string tableName, return H_OK; } -int EmbCacheManagerImpl::Serialize(std::string tableName, std::vector& buffer) +int EmbCacheManagerImpl::Serialize(const std::string& tableName, std::vector& buffer) { int checkTableNameRet = CheckValidTableName(tableName); if (checkTableNameRet != H_OK) { @@ -265,7 +266,7 @@ int EmbCacheManagerImpl::Serialize(std::string tableName, std::vector& buf return H_OK; } -int EmbCacheManagerImpl::Deserialize(std::string tableName, const std::vector& buffer) +int EmbCacheManagerImpl::Deserialize(const std::string& tableName, const std::vector& buffer) { int checkTableNameRet = CheckValidTableName(tableName); if (checkTableNameRet != H_OK) { @@ -290,7 +291,7 @@ void EmbCacheManagerImpl::Destroy() embTables.clear(); } -int EmbCacheManagerImpl::CheckValidTableName(std::string tableName) +int EmbCacheManagerImpl::CheckValidTableName(const std::string& tableName) { if (tableName.size() > TABLE_NAME_MAX_SIZE) { ExternalLogger::PrintLog(LogLevel::ERROR, @@ -345,7 +346,7 @@ bool EmbCacheManagerImpl::CheckValidThreadNum(uint32_t threadNum) return true; } -int EmbCacheManagerImpl::CheckGetSwapPairsAndKey2Offset(std::string tableName, const KeyOffsetPair& swapInKoPair, +int EmbCacheManagerImpl::CheckGetSwapPairsAndKey2Offset(const std::string& tableName, const KeyOffsetPair& swapInKoPair, const KeyOffsetPair& swapOutKoPair) { if (!swapInKoPair.first.empty() || !swapInKoPair.second.empty() || !swapOutKoPair.first.empty() || diff --git a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.h b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.h index d8c4ed9b..81499172 100644 --- a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.h +++ b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.h @@ -35,32 +35,33 @@ public: int CreateCacheForTable(const EmbCacheInfo& embCacheInfo, const std::vector& initializerInfos, int64_t invalidKey, uint64_t prefillBufferSize, uint32_t refillThreadNum) override; - int GetSwapPairsAndKey2Offset(std::string tableName, std::vector& keys, KeyOffsetPair& swapInKoPair, - KeyOffsetPair& swapOutKoPair) override; + int GetSwapPairsAndKey2Offset(const std::string& tableName, std::vector& keys, + KeyOffsetPair& swapInKoPair, KeyOffsetPair& swapOutKoPair) override; - int EmbeddingLookup(std::string tableName, const std::vector& keys, float* embAddr, + int EmbeddingLookup(const std::string& tableName, const std::vector& keys, float* embAddr, uint32_t threadNum) override; - int EmbeddingLookupAddrs(std::string tableName, const std::vector& keys, std::vector& addrs, - uint32_t threadNum) override; + int EmbeddingLookupAddrs(const std::string& tableName, const std::vector& keys, + std::vector& addrs, uint32_t threadNum) override; - int EmbeddingUpdate(std::string tableName, const std::vector& keys, float* embAddr, + int EmbeddingUpdate(const std::string& tableName, const std::vector& keys, float* embAddr, uint32_t threadNum) override; - int EmbeddingRemove(std::string tableName, const std::vector& keys, uint32_t threadNum) override; + int EmbeddingRemove(const std::string& tableName, const std::vector& keys, uint32_t threadNum) override; - int EmbeddingLookupAndRemove(std::string tableName, const std::vector& keys, float* embAddr, + int EmbeddingLookupAndRemove(const std::string& tableName, const std::vector& keys, float* embAddr, uint32_t threadNum) override; - int RemoveEmbsByKeys(std::string tableName, const std::vector& keys) override; + int RemoveEmbsByKeys(const std::string& tableName, const std::vector& keys) override; int GetEmbTableNames(std::vector& allTableNames) override; - int ExportDeviceKeyOffsetPairs(std::string tableName, std::vector>& koVec) override; + int ExportDeviceKeyOffsetPairs(const std::string& tableName, + std::vector>& koVec) override; - int Serialize(std::string tableName, std::vector& buffer) override; + int Serialize(const std::string& tableName, std::vector& buffer) override; - int Deserialize(std::string tableName, const std::vector& buffer) override; + int Deserialize(const std::string& tableName, const std::vector& buffer) override; void Destroy() override; @@ -71,13 +72,13 @@ private: std::map offsetMappers; std::map embTables; - int CheckValidTableName(std::string tableName); + int CheckValidTableName(const std::string& tableName); bool CheckInitializer(uint32_t extEmbSize, std::vector initializerInfos); bool CheckValidThreadNum(uint32_t threadNum); - int CheckGetSwapPairsAndKey2Offset(std::string tableName, const KeyOffsetPair& swapInKoPair, + int CheckGetSwapPairsAndKey2Offset(const std::string& tableName, const KeyOffsetPair& swapInKoPair, const KeyOffsetPair& swapOutKoPair); int CheckCreateTableName(const std::string& tableName); diff --git a/src/AccCTR/src/embedding_cache/offset_mapper/address_mapper.h b/src/AccCTR/src/embedding_cache/offset_mapper/address_mapper.h index 8b6eefae..e88b2a3a 100644 --- a/src/AccCTR/src/embedding_cache/offset_mapper/address_mapper.h +++ b/src/AccCTR/src/embedding_cache/offset_mapper/address_mapper.h @@ -220,13 +220,15 @@ private: { std::unique_lock lock(producerMutex); while (!stop) { - if (BufferBin.GetLength() < maxBufferSize && !full) { - Produce(); - } else if (!full) { - producerCv.wait(lock); - } else { + if (full) { fullCv.wait(lock); + continue; + } + if (BufferBin.GetLength() < maxBufferSize) { + Produce(); + continue; } + producerCv.wait(lock); } } }; diff --git a/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h b/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h index 5eb48789..969845ee 100644 --- a/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h +++ b/src/AccCTR/src/embedding_cache/offset_mapper/mapper_base.h @@ -291,7 +291,7 @@ public: } /* get proper bucket count */ - uint32_t bucketCount = reserve < 128 ? 128 : reserve; + uint32_t bucketCount = std::max(reserve, uint32_t(128)); if (bucketCount > gPrimes[gPrimesCount - 1]) { bucketCount = gPrimes[gPrimesCount - 1]; } else { diff --git a/src/AccCTR/src/include/embedding_cache.h b/src/AccCTR/src/include/embedding_cache.h index 69a41136..1aa2248a 100644 --- a/src/AccCTR/src/include/embedding_cache.h +++ b/src/AccCTR/src/include/embedding_cache.h @@ -182,7 +182,7 @@ public: * @Param swapOutKoPair: 输出参数,需要换出的Key-offset pair * @Return errorCode */ - virtual int GetSwapPairsAndKey2Offset(std::string tableName, std::vector &keys, + virtual int GetSwapPairsAndKey2Offset(const std::string& tableName, std::vector &keys, KeyOffsetPair &swapInKoPair, KeyOffsetPair &swapOutKoPair) = 0; /* * @@ -193,7 +193,7 @@ public: * @Param threadNum: 线程数 * @Return errorCode */ - virtual int EmbeddingLookup(std::string tableName, const std::vector &keys, float *embAddr, + virtual int EmbeddingLookup(const std::string& tableName, const std::vector &keys, float *embAddr, uint32_t threadNum = 4) = 0; /* * @@ -204,7 +204,7 @@ public: * @Param threadNum: 线程数 * @Return errorCode */ - virtual int EmbeddingLookupAddrs(std::string tableName, const std::vector &keys, + virtual int EmbeddingLookupAddrs(const std::string& tableName, const std::vector &keys, std::vector &addrs, uint32_t threadNum = 4) = 0; /* * @@ -216,8 +216,8 @@ public: * @Param threadNum: 线程数 * @Return errorCode */ - virtual int EmbeddingLookupAndRemove(std::string tableName, const std::vector &keys, float *embAddr, - uint32_t threadNum = 4) = 0; + virtual int EmbeddingLookupAndRemove(const std::string& tableName, const std::vector& keys, + float* embAddr, uint32_t threadNum = 4) = 0; /* * * 更新Embedding @@ -227,7 +227,7 @@ public: * @Param threadNum: 线程数 * @Return errorCode */ - virtual int EmbeddingUpdate(std::string tableName, const std::vector &keys, float *embAddr, + virtual int EmbeddingUpdate(const std::string& tableName, const std::vector &keys, float *embAddr, uint32_t threadNum = 4) = 0; /* * @@ -236,7 +236,8 @@ public: * @Param keys: 待移除的keys * @Return errorCode */ - virtual int EmbeddingRemove(std::string tableName, const std::vector &keys, uint32_t threadNum = 4) = 0; + virtual int EmbeddingRemove(const std::string& tableName, const std::vector& keys, + uint32_t threadNum = 4) = 0; /* * * 将需要被淘汰的keys从offsetMapper的记录中移除,同时也在EmbLocalTable中移除,并将存储其embedding的内存位置记为可复用 @@ -244,7 +245,7 @@ public: * @Param keys: 待淘汰的keys * @Return errorCode */ - virtual int RemoveEmbsByKeys(std::string tableName, const std::vector &keys) = 0; + virtual int RemoveEmbsByKeys(const std::string& tableName, const std::vector &keys) = 0; /* * * 获取所有table names @@ -259,7 +260,7 @@ public: * koVec: 输出参数 * @Return errorCode */ - virtual int ExportDeviceKeyOffsetPairs(std::string tableName, + virtual int ExportDeviceKeyOffsetPairs(const std::string& tableName, std::vector> &koVec) = 0; /* * @@ -268,7 +269,7 @@ public: * @Param buffer: 输出参数,存储序列化之后的信息 * @Return errorCode */ - virtual int Serialize(std::string tableName, std::vector &buffer) = 0; + virtual int Serialize(const std::string& tableName, std::vector &buffer) = 0; /* * * 将当前table的序列化信息进行反序列化 @@ -276,7 +277,7 @@ public: * @Param buffer: 输入参数,将buffer中的内容进行反序列化 * @Return errorCode */ - virtual int Deserialize(std::string tableName, const std::vector &buffer) = 0; + virtual int Deserialize(const std::string& tableName, const std::vector &buffer) = 0; /* * * 析构所有embCache,释放内存 -- Gitee