From 31540d9944b6716f4deb18d8fba0cd21f909b722 Mon Sep 17 00:00:00 2001 From: xiangpx Date: Mon, 4 Aug 2025 14:55:38 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A4=9A=E7=BA=A7=E7=BC=93=E5=AD=983?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../csrc/emb_table/emb_mem_pool.cpp | 91 +++++ .../csrc/emb_table/emb_memory_pool.h | 106 ++++++ .../csrc/emb_table/emb_table.h | 349 ++++++++++++++++++ .../csrc/emb_table/initializer.h | 127 +++++++ 4 files changed, 673 insertions(+) create mode 100644 torchrec/torchrec_embcache/src/torchrec_embcache/csrc/emb_table/emb_mem_pool.cpp create mode 100644 torchrec/torchrec_embcache/src/torchrec_embcache/csrc/emb_table/emb_memory_pool.h create mode 100644 torchrec/torchrec_embcache/src/torchrec_embcache/csrc/emb_table/emb_table.h create mode 100644 torchrec/torchrec_embcache/src/torchrec_embcache/csrc/emb_table/initializer.h diff --git a/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/emb_table/emb_mem_pool.cpp b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/emb_table/emb_mem_pool.cpp new file mode 100644 index 00000000..4e9f071d --- /dev/null +++ b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/emb_table/emb_mem_pool.cpp @@ -0,0 +1,91 @@ +/* + * Copyright (c) huawei Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include "emb_memory_pool.h" + +#include + +#include "securec.h" + +#include "initializer.h" + +namespace Embcache { + +BeforePutFuncState EmbMemoryPool::GetNewValueToBeInserted(uint64_t& value, uint32_t maxRetry) +{ + for (uint32_t i = 0; i < maxRetry; i++) { + if (bufferBin_.pop(value)) { + GetEmbMemoryPoolThreadPool().enqueue([this] { Produce(); }); + return BeforePutFuncState::BEFORE_SUCCESS; + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + LOG_ERROR("Failed to get new address for embedding, it is likely due to refill thread " + "memory allocation failure or max retry has been reached. " + "Please check for memory alloc error or increase refill thread num!"); + return BeforePutFuncState::BEFORE_FAIL; +} + +void EmbMemoryPool::GetValueToBeRecycled(uint64_t value) +{ + recycleBin_.push(value); +} + +bool EmbMemoryPool::GetNewAddr(uint64_t& newAddr) +{ + std::lock_guard lg(getAddrMutex_); + if (HM_UNLIKELY(currentMemoryUint_.leftCapacity <= 0)) { + /* need to expand memory */ + uint64_t maxSize = std::min(maxExpandSize_, totalLeftVocabSize_ * itemSize_); + uint64_t newSize = currentMemoryUint_.capacity + ? std::min(currentMemoryUint_.capacity * dynamicExpandRatio_, maxSize) + : itemSize_; + if (newSize == 0) { // 所有hostVocabSize均已分配 + return false; + } + auto newAddress = reinterpret_cast(malloc(newSize)); + if (newAddress == 0) { + LOG_ERROR("Refill thread allocate memory failed!"); + return false; + } + expandedMemory_.emplace_back(newAddress, newSize); + currentMemoryUint_.address = newAddress; + currentMemoryUint_.capacity = newSize; + currentMemoryUint_.leftCapacity = newSize; + totalLeftVocabSize_ -= newSize / itemSize_; + } + newAddr = currentMemoryUint_.address + currentMemoryUint_.capacity - currentMemoryUint_.leftCapacity; + currentMemoryUint_.leftCapacity -= itemSize_; + return true; +} + +void EmbMemoryPool::Produce() +{ + uint64_t newAddr; + if (!recycleBin_.pop(newAddr) && !GetNewAddr(newAddr)) { + return; + } + + // init embedding + if (embConfig_.initializerRandomPoolSize == -1) { + Initializer::InitEmbeddingWeights(reinterpret_cast(newAddr), embConfig_); + } else { + Initializer::InitEmbeddingWeightsLimitPool(reinterpret_cast(newAddr), embConfig_); + } + + // init optimizer + auto ret = memset_s(reinterpret_cast(newAddr) + embConfig_.embDim, + embConfig_.optimNum * embConfig_.embDim * sizeof(float), 0, + embConfig_.optimNum * embConfig_.embDim * sizeof(float)); + if (ret != EOK) { + throw std::runtime_error("memset_s failed when init optimizer data."); + } + + bufferBin_.push(newAddr); +} + +} // namespace Embcache \ No newline at end of file diff --git a/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/emb_table/emb_memory_pool.h b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/emb_table/emb_memory_pool.h new file mode 100644 index 00000000..29e13315 --- /dev/null +++ b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/emb_table/emb_memory_pool.h @@ -0,0 +1,106 @@ +/* + * Copyright (c) huawei Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#ifndef EMBEDDING_CACHE_EMB_MEMORY_POOL_H +#define EMBEDDING_CACHE_EMB_MEMORY_POOL_H + +#include +#include +#include +#include + +#include "common/common.h" +#include "utils/thread_pool.h" +#include "utils/safe_queue.h" +#include "utils/logger.h" + +namespace Embcache { + +using EmExpandMemUint = struct EmExpandMemoryUint { + uint64_t address = 0; + uint64_t capacity = 0; + uint64_t leftCapacity = 0; + + EmExpandMemoryUint() = default; + + EmExpandMemoryUint(uint64_t a, uint64_t c) : address(a), capacity(c), leftCapacity(c) {} +}; + +class EmbMemoryPool { +public: + EmbMemoryPool(const EmbConfig& embConfig, uint64_t bufferSize, uint64_t hostVocabSize) + : embConfig_(embConfig), + maxBufferSize_(bufferSize), + totalLeftVocabSize_(hostVocabSize) + { + itemSize_ = (embConfig.optimNum + 1) * embConfig.embDim * sizeof(float); + maxExpandSize_ = maxBufferSize_ * itemSize_; + char* poolSizeStr = getenv("EMB_MEMORY_POOL_SIZE"); + if (poolSizeStr) { + char* endptr = nullptr; + embMemoryPoolSize_ = strtoul(poolSizeStr, &endptr, 10); + if (endptr == poolSizeStr || *endptr != '\0') { + LOG_ERROR("env EMB_MEMORY_POOL_SIZE is not a valid number"); + throw std::runtime_error("env EMB_MEMORY_POOL_SIZE is not a valid number"); + } + if (embMemoryPoolSize_ == 0) { + LOG_ERROR("env EMB_MEMORY_POOL_SIZE = 0, it is invalid"); + throw std::runtime_error("env EMB_MEMORY_POOL_SIZE is invalid"); + } + } + LOG_INFO("EmbMemoryPool embMemoryPoolSize: {}", embMemoryPoolSize_); + for (uint64_t i = 0; i < embMemoryPoolSize_; i++) { + Produce(); + } + } + + EmbMemoryPool(const EmbMemoryPool& pool) = delete; + + EmbMemoryPool& operator=(const EmbMemoryPool& pool) = delete; + + EmbMemoryPool(EmbMemoryPool&& pool) = delete; + + EmbMemoryPool& operator=(EmbMemoryPool&& pool) = delete; + + ~EmbMemoryPool() + { + for (const auto& memUint : expandedMemory_) { + free(reinterpret_cast(memUint.address)); + } + } + + BeforePutFuncState GetNewValueToBeInserted(uint64_t& value, uint32_t maxRetry = 1000); + void GetValueToBeRecycled(uint64_t value); + +private: + bool GetNewAddr(uint64_t& newAddr); + void Produce(); + +private: + std::vector expandedMemory_; + EmbConfig embConfig_; + +private: + uint64_t maxBufferSize_; + uint64_t totalLeftVocabSize_; + + std::mutex getAddrMutex_; + + SafeQueue bufferBin_; + SafeQueue recycleBin_; + + EmExpandMemUint currentMemoryUint_{}; + uint64_t dynamicExpandRatio_ = 2; + + uint64_t maxExpandSize_ = 0; + uint64_t itemSize_; + + uint64_t embMemoryPoolSize_ = 102400; +}; + +} // namespace Embcache +#endif // EMBEDDING_CACHE_EMB_MEMORY_POOL_H diff --git a/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/emb_table/emb_table.h b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/emb_table/emb_table.h new file mode 100644 index 00000000..95e9e16c --- /dev/null +++ b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/emb_table/emb_table.h @@ -0,0 +1,349 @@ +/* + * Copyright (c) huawei Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#ifndef EMBEDDING_CACHE_EMB_TABLE_H +#define EMBEDDING_CACHE_EMB_TABLE_H + +#include +#include +#include +#include +#include +#include +#include + +#include +#include "securec.h" + +#include "common/common.h" +#include "common/constants.h" +#include "emb_memory_pool.h" +#include "hash_table/fast_hashmap.h" +#include "initializer.h" +#include "utils/logger.h" + +namespace Embcache { + +class EmbTable { +public: + explicit EmbTable(const EmbConfig& embConfig) + : config_(embConfig), + extEmbDim_((1 + embConfig.optimNum) * embConfig.embDim) + { + } + + virtual ~EmbTable() = default; + + virtual void FindOrInsert(const std::vector& keys, float* outEmbs, std::vector outOptims) = 0; + virtual void InsertOrAssign(const std::vector& keys, float* inEmbs, std::vector inOptims) = 0; + virtual void RemoveEmbedding(const std::vector& keys) = 0; + +protected: + EmbConfig config_; + int32_t extEmbDim_; // embDim + OptimNum * embDim +}; + +class EmbTableUnorderedMap : public EmbTable { +public: + explicit EmbTableUnorderedMap(const EmbConfig& embConfig) : EmbTable(embConfig) {} + + void FindOrInsert(const std::vector& keys, float* outEmbs, std::vector outOptims) override + { + std::lock_guard lk(mtx_); + auto embDim = config_.embDim; + auto optimNum = config_.optimNum; + if (outOptims.size() != optimNum) { + LOG_ERROR("outOptims size {} is not equal to optimNum {}", outOptims.size(), optimNum); + throw std::runtime_error("outOptims size is not equal to optimNum"); + } + for (uint64_t i = 0; i < keys.size(); i++) { + auto key = keys[i]; + auto it = table_.find(key); + if (it == table_.end()) { + auto res = table_.emplace(key, extEmbDim_); + it = res.first; + Initializer::InitEmbeddingWeights(it->second.data(), config_); + } + auto& emb = it->second; + + size_t size = embDim * sizeof(float); + if (outEmbs == nullptr) { + LOG_ERROR("outEmbs is nullptr"); + throw std::runtime_error("outEmbs is nullptr"); + } + auto rc = memcpy_s(outEmbs + i * embDim, size, emb.data(), size); + if (rc != 0) { + LOG_ERROR("memcpy_s emb to outEmbs[{}] failed. ret: {}", i, rc); + throw std::runtime_error("memcpy_s emb to outEmbs failed."); + } + + if (optimNum > 0) { + if (outOptims[0] == nullptr) { + LOG_ERROR("outOptims[0] is nullptr"); + throw std::runtime_error("outOptims[0] is nullptr"); + } + rc = memcpy_s(outOptims[0] + i * embDim, size, emb.data() + embDim, size); + if (rc != 0) { + LOG_ERROR("memcpy_s optim1 to outOptims[{}][0] failed. ret: {}", i, rc); + throw std::runtime_error("memcpy_s optim1 to outOptims[0] failed."); + } + } + + if (optimNum > 1) { + if (outOptims[1] == nullptr) { + LOG_ERROR("outOptims[1] is nullptr"); + throw std::runtime_error("outOptims[1] is nullptr"); + } + rc = memcpy_s(outOptims[1] + i * embDim, size, emb.data() + optimNum * embDim, size); + if (rc != 0) { + LOG_ERROR("memcpy_s optim2 to outOptims[{}][1] failed. ret: {}", i, rc); + throw std::runtime_error("memcpy_s optim2 to outOptims[1] failed."); + } + } + } + } + + void InsertOrAssign(const std::vector& keys, float* inEmbs, std::vector inOptims) override + { + std::lock_guard lk(mtx_); + auto embDim = config_.embDim; + auto optimNum = config_.optimNum; + for (uint64_t i = 0; i < keys.size(); i++) { + auto key = keys[i]; + auto it = table_.find(key); + if (it == table_.end()) { + auto res = table_.emplace(key, extEmbDim_); + it = res.first; + } + auto& emb = it->second; + + size_t size = embDim * sizeof(float); + if (inEmbs == nullptr) { + LOG_ERROR("inEmbs is nullptr"); + throw std::runtime_error("inEmbs is nullptr"); + } + auto rc = memcpy_s(emb.data(), size, inEmbs + i * embDim, size); + if (rc != 0) { + LOG_ERROR("memcpy_s emb[{}] to table failed. ret: {}", i, rc); + throw std::runtime_error("memcpy_s emb to table failed."); + } + + if (optimNum > 0) { + if (inOptims[0] == nullptr) { + LOG_ERROR("inOptims[0] is nullptr"); + throw std::runtime_error("inOptims[0] is nullptr"); + } + rc = memcpy_s(emb.data() + embDim, size, inOptims[0] + i * embDim, size); + if (rc != 0) { + LOG_ERROR("memcpy_s optim1[{}] to table failed. ret: {}", i, rc); + throw std::runtime_error("memcpy_s optim1 to table failed."); + } + } + + if (optimNum > 1) { + if (inOptims[1] == nullptr) { + LOG_ERROR("inOptims[1] is nullptr"); + throw std::runtime_error("inOptims[1] is nullptr"); + } + rc = memcpy_s(emb.data() + optimNum * embDim, size, inOptims[1] + i * embDim, size); + if (rc != 0) { + LOG_ERROR("memcpy_s optim2[{}] to table failed. ret: {}", i, rc); + throw std::runtime_error("memcpy_s optim2 to table failed."); + } + } + } + } + + void RemoveEmbedding(const std::vector& keys) override + { + std::lock_guard lk(mtx_); + for (auto key : keys) { + table_.erase(key); + } + } + +private: + std::unordered_map> table_; + std::mutex mtx_; +}; + +class EmbTableFastHashMap : public EmbTable { +public: + explicit EmbTableFastHashMap(const EmbConfig& embConfig) : EmbTable(embConfig) + { + memPoolPtr_ = std::make_shared(embConfig, EmbMemPoolConfigConstants::bufferSize, + EmbMemPoolConfigConstants::hostVocabSize); + hostVocabSize_ = EmbMemPoolConfigConstants::hostVocabSize; + + fastHashMapPtr_ = std::make_shared(); + + uint64_t fastHashMapReserveBucketNum = FAST_HASHMAP_RESERVE_BUCKET_NUM; + char* fastHashMapReserveStr = getenv("FAST_HASHMAP_RESERVE_BUCKET_NUM"); + if (fastHashMapReserveStr) { + char* endptr = nullptr; + fastHashMapReserveBucketNum = strtoul(fastHashMapReserveStr, &endptr, 10); + if (endptr == fastHashMapReserveStr || *endptr != '\0') { + LOG_ERROR("env FAST_HASHMAP_RESERVE_BUCKET_NUM is not a valid number"); + throw std::runtime_error("env FAST_HASHMAP_RESERVE_BUCKET_NUM is not a valid number"); + } + } + fastHashMapPtr_->Init(fastHashMapReserveBucketNum); + LOG_INFO("FAST_HASHMAP_RESERVE_BUCKET_NUM: {}", fastHashMapReserveBucketNum); + } + + EmbTableFastHashMap(const EmbTableFastHashMap&) = delete; + EmbTableFastHashMap& operator=(const EmbTableFastHashMap&) = delete; + EmbTableFastHashMap(EmbTableFastHashMap&&) = delete; + EmbTableFastHashMap& operator=(EmbTableFastHashMap&&) = delete; + + ~EmbTableFastHashMap() override + { + fastHashMapPtr_->Destroy(); + } + + void FindOrInsert(const std::vector& keys, float* outEmbs, std::vector outOptims) override + { + auto embDim = config_.embDim; + auto optimNum = config_.optimNum; + at::parallel_for( + 0, keys.size(), std::ceil(keys.size() * 1.0 / at::get_num_threads()), [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; ++i) { + const auto key = keys[i]; + uint64_t addrValue = 0; + + FkvState ret = fastHashMapPtr_->FindOrInsert(key, addrValue, [&]() { + const uint64_t currentSize = fastHashMapPtr_->GetCurrentSize(); + if (HM_UNLIKELY(currentSize >= hostVocabSize_)) { + LOG_ERROR("No enough space at host, currentSize: {}, hostVocabSize: {}", currentSize, + hostVocabSize_); + return BeforePutFuncState::BEFORE_NO_SPACE; + } + return memPoolPtr_->GetNewValueToBeInserted(addrValue); + }); + if (ret == FkvState::FKV_FAIL) { + LOG_ERROR("fastHashMapPtr->FindOrInsert failed!"); + throw std::runtime_error("fastHashMapPtr->FindOrInsert failed!"); + } + if (ret == FkvState::FKV_BEFORE_PUT_FUNC_FAIL) { + LOG_ERROR("memory alloc failed!"); + throw std::runtime_error("memory alloc failed!"); + } + + size_t size = embDim * sizeof(float); + if (outEmbs == nullptr) { + LOG_ERROR("outEmbs is nullptr"); + throw std::runtime_error("outEmbs is nullptr"); + } + auto rc = memcpy_s(outEmbs + i * embDim, size, reinterpret_cast(addrValue), size); + if (rc != 0) { + LOG_ERROR("memcpy_s emb[{}] to outEmbs failed. ret: {}", i, rc); + throw std::runtime_error("memcpy_s emb to outEmbs failed."); + } + + if (optimNum > 0) { + if (outOptims[0] == nullptr) { + LOG_ERROR("outOptims[0] is nullptr"); + throw std::runtime_error("outOptims[0] is nullptr"); + } + rc = memcpy_s(outOptims[0] + i * embDim, size, + reinterpret_cast(addrValue) + embDim, size); + if (rc != 0) { + LOG_ERROR("memcpy_s optim1[{}] to outOptims[0] failed. ret: {}", i, rc); + throw std::runtime_error("memcpy_s optim1 to outOptims[0] failed."); + } + } + + if (optimNum > 1) { + if (outOptims[1] == nullptr) { + LOG_ERROR("outOptims[1] is nullptr"); + throw std::runtime_error("outOptims[1] is nullptr"); + } + rc = memcpy_s(outOptims[1] + i * embDim, size, + reinterpret_cast(addrValue) + optimNum * embDim, size); + if (rc != 0) { + LOG_ERROR("memcpy_s optim2[{}] to outOptims[1] failed. ret: {}", i, rc); + throw std::runtime_error("memcpy_s optim2 to outOptims[1] failed."); + } + } + } + }); + } + + void InsertOrAssign(const std::vector& keys, float* inEmbs, std::vector inOptims) override + { + auto embDim = config_.embDim; + auto optimNum = config_.optimNum; + at::parallel_for( + 0, keys.size(), std::ceil(keys.size() * 1.0 / at::get_num_threads()), [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; ++i) { + const auto key = keys[i]; + uint64_t addrValue = 0; + + FkvState ret = fastHashMapPtr_->FindOrInsert(key, addrValue, [&]() { + const uint64_t currentSize = fastHashMapPtr_->GetCurrentSize(); + if (HM_UNLIKELY(currentSize >= hostVocabSize_)) { + LOG_ERROR("No enough space at host, currentSize: {}, hostVocabSize: {}", currentSize, + hostVocabSize_); + return BeforePutFuncState::BEFORE_NO_SPACE; + } + return memPoolPtr_->GetNewValueToBeInserted(addrValue); + }); + if (ret == FkvState::FKV_FAIL) { + LOG_ERROR("fastHashMapPtr->InsertOrAssign failed!"); + throw std::runtime_error("fastHashMapPtr->InsertOrAssign failed!"); + } + if (ret == FkvState::FKV_BEFORE_PUT_FUNC_FAIL) { + LOG_ERROR("memory alloc failed!"); + throw std::runtime_error("memory alloc failed!"); + } + + size_t size = embDim * sizeof(float); + auto rc = memcpy_s((float*)addrValue, size, inEmbs + i * embDim, size); + if (rc != 0) { + LOG_ERROR("memcpy_s emb[{}] to addrValue failed. ret: {}", i, rc); + throw std::runtime_error("memcpy_s emb to addrValue failed."); + } + if (optimNum > 0) { + rc = memcpy_s((float*)addrValue + embDim, size, inOptims[0] + i * embDim, size); + if (rc != 0) { + LOG_ERROR("memcpy_s optim1[{}] to addrValue failed. ret: {}", i, rc); + throw std::runtime_error("memcpy_s optim1 to addrValue failed."); + } + } + if (optimNum > 1) { + rc = memcpy_s((float*)addrValue + optimNum * embDim, size, inOptims[1] + i * embDim, size); + if (rc != 0) { + LOG_ERROR("memcpy_s optim2[{}] to addrValue failed. ret: {}", i, rc); + throw std::runtime_error("memcpy_s optim2 to addrValue failed."); + } + } + } + }); + } + + void RemoveEmbedding(const std::vector& keys) override + { + for (auto key : keys) { + FkvState ret = fastHashMapPtr_->Remove(key, [&](uint64_t value) { + memPoolPtr_->GetValueToBeRecycled(value); + return BeforeRemoveFuncState::BEFORE_SUCCESS; + }); + if (ret == FkvState::FKV_BEFORE_REMOVE_FUNC_FAIL) { + LOG_ERROR("remove embedding failed!"); + throw std::runtime_error("remove embedding failed!"); + } + } + } + +private: + std::shared_ptr memPoolPtr_; + std::shared_ptr fastHashMapPtr_; + uint64_t hostVocabSize_; +}; + +} // namespace Embcache +#endif // EMBEDDING_CACHE_EMB_TABLE_H \ No newline at end of file diff --git a/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/emb_table/initializer.h b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/emb_table/initializer.h new file mode 100644 index 00000000..7c6cac81 --- /dev/null +++ b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/emb_table/initializer.h @@ -0,0 +1,127 @@ +/* + * Copyright (c) huawei Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#ifndef EMBEDDING_CACHE_EMB_TABLE_INITIALIZER_H +#define EMBEDDING_CACHE_EMB_TABLE_INITIALIZER_H + +#include +#include +#include +#include +#include +#include +#include + +#include "securec.h" +#include "common/common.h" +#include "utils/logger.h" + +using RandomVPool = std::vector>; + +namespace Embcache { + +struct WeightInitParam { + float mean; + float stddev; + float minVal; + float maxVal; +}; + +class Initializer { +public: + static void GenUniform(float* array, size_t size, float minVal, float maxVal) + { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution distrib(minVal, maxVal); + std::generate(array, array + size, [&]() { return distrib(gen); }); + } + + static void GenLinear(float* array, size_t size, float minVal, float maxVal) + { + if (size == 0) { + return; + } + if (size == 1) { + array[0] = minVal; + return; + } + for (size_t i = 0; i < size; ++i) { + array[i] = minVal + 1.0 * i / (size - 1) * (maxVal - minVal); + } + } + + static void GenTruncatedNormal(float* array, size_t size, WeightInitParam weightParam, + unsigned int seed = std::random_device{}()) + { + if (array == nullptr || size == 0 || weightParam.stddev <= 0.0f || weightParam.minVal >= weightParam.maxVal) { + return; + } + + std::mt19937 gen(seed); + std::normal_distribution distrib(weightParam.mean, weightParam.stddev); + + std::generate(array, array + size, [&]() { + float val = distrib(gen); + while (val < weightParam.minVal || val > weightParam.maxVal) { + val = distrib(gen); + } + return val; + }); + } + + static void InitEmbeddingWeights(float* embeddingAddr, const EmbConfig& cfg) + { + if (cfg.initializerType == InitializerType::LINEAR) { + Initializer::GenLinear(embeddingAddr, cfg.embDim, cfg.weightInitMin, cfg.weightInitMax); + } else if (cfg.initializerType == InitializerType::TRUNCATED_NORMAL) { + WeightInitParam param = {cfg.weightInitMean, cfg.weightInitStddev, cfg.weightInitMin, cfg.weightInitMax}; + Initializer::GenTruncatedNormal(embeddingAddr, cfg.embDim, param); + } else { + Initializer::GenUniform(embeddingAddr, cfg.embDim, cfg.weightInitMin, cfg.weightInitMax); + } + } + + static void InitEmbeddingWeightsLimitPool(float* embeddingAddr, const EmbConfig& cfg) + { + static ska::flat_hash_map staticPoolMap; + static std::default_random_engine engine; + if (staticPoolMap.find(cfg.embDim) == staticPoolMap.end()) { + engine.seed(abs(cfg.seed)); + RandomVPool staticPool = + std::vector>(cfg.initializerRandomPoolSize, std::vector(cfg.embDim)); + for (int i = 0; i < cfg.initializerRandomPoolSize; i++) { + if (cfg.initializerType == InitializerType::LINEAR) { + Initializer::GenLinear(staticPool[i].data(), cfg.embDim, cfg.weightInitMin, cfg.weightInitMax); + } else if (cfg.initializerType == InitializerType::TRUNCATED_NORMAL) { + WeightInitParam param = {cfg.weightInitMean, cfg.weightInitStddev, cfg.weightInitMin, + cfg.weightInitMax}; + Initializer::GenTruncatedNormal(staticPool[i].data(), cfg.embDim, param); + } else { + Initializer::GenUniform(staticPool[i].data(), cfg.embDim, cfg.weightInitMin, cfg.weightInitMax); + } + } + staticPoolMap.emplace(cfg.embDim, staticPool); + } + RandomVPool& staticPool = staticPoolMap.find(cfg.embDim)->second; + std::uniform_int_distribution uDistribution(0, cfg.initializerRandomPoolSize - 1); + int randIndex = uDistribution(engine); + if (embeddingAddr == nullptr) { + LOG_ERROR("embeddingAddr is nullptr"); + throw std::runtime_error("embeddingAddr is nullptr"); + } + auto ret = memcpy_s(embeddingAddr, cfg.embDim * sizeof(float), staticPool[randIndex].data(), + cfg.embDim * sizeof(float)); + if (ret != 0) { + LOG_ERROR("memcpy_s failed when init optimizer data."); + throw std::runtime_error("memcpy_s failed when init optimizer data."); + } + } +}; + +} // namespace Embcache +#endif // EMBEDDING_CACHE_EMB_TABLE_INITIALIZER_H -- Gitee