diff --git a/mxrec_add_ons/rec_for_torch/operators/split_embedding_codegen_forward_unweighted/op_host/split_embedding_codegen_forward_unweighted.cpp b/mxrec_add_ons/rec_for_torch/operators/split_embedding_codegen_forward_unweighted/op_host/split_embedding_codegen_forward_unweighted.cpp
index c95a094fd4ced1e2dbe429760f48daae5f67bfde..daa4ded2ac155b88c54d03d6aa4bdb4f8e89905e 100644
--- a/mxrec_add_ons/rec_for_torch/operators/split_embedding_codegen_forward_unweighted/op_host/split_embedding_codegen_forward_unweighted.cpp
+++ b/mxrec_add_ons/rec_for_torch/operators/split_embedding_codegen_forward_unweighted/op_host/split_embedding_codegen_forward_unweighted.cpp
@@ -215,6 +215,9 @@ public:
         this->Input("hash_indices")
             .ParamType(OPTIONAL).DataType({ge::DT_INT64})
             .Format({ge::FORMAT_ND}).UnknownShapeFormat({ge::FORMAT_ND});
+        this->Input("offset_per_key")
+            .ParamType(OPTIONAL).DataType({ge::DT_INT64})
+            .Format({ge::FORMAT_ND}).UnknownShapeFormat({ge::FORMAT_ND});
         this->Output("out")
             .ParamType(REQUIRED).DataType({ge::DT_FLOAT})
             .Format({ge::FORMAT_ND}).UnknownShapeFormat({ge::FORMAT_ND});
diff --git a/mxrec_add_ons/rec_for_torch/operators/split_embedding_codegen_forward_unweighted/op_kernel/split_embedding_codegen_forward_unweighted.cpp b/mxrec_add_ons/rec_for_torch/operators/split_embedding_codegen_forward_unweighted/op_kernel/split_embedding_codegen_forward_unweighted.cpp
index 74f4006aa7794b5d39ed31d708ac513f6cd724a0..ec578f90b8e5dc695d8874791e2ead35a50bc67a 100644
--- a/mxrec_add_ons/rec_for_torch/operators/split_embedding_codegen_forward_unweighted/op_kernel/split_embedding_codegen_forward_unweighted.cpp
+++ b/mxrec_add_ons/rec_for_torch/operators/split_embedding_codegen_forward_unweighted/op_kernel/split_embedding_codegen_forward_unweighted.cpp
@@ -18,10 +18,12 @@ See the License for the specific language governing permissions and
 extern "C" __global__ __aicore__ void split_embedding_codegen_forward_unweighted(
     GM_ADDR devWeights, GM_ADDR uvmWeights, GM_ADDR lxuCacheWeights, GM_ADDR weightsPlacements,
     GM_ADDR weightsOffsets, GM_ADDR dOffsets, GM_ADDR indices, GM_ADDR offsets, GM_ADDR lxuCacheLocations, GM_ADDR hashIndices,
+    GM_ADDR indiceSizeCumsum,
     GM_ADDR out, GM_ADDR workspace, GM_ADDR tiling)
 {
     SplitEmbeddingCodegenForwardUnweighted::Args args{
-        devWeights, weightsPlacements, weightsOffsets, dOffsets, indices, offsets, hashIndices, out, tiling, workspace};
+        devWeights, weightsPlacements, weightsOffsets, dOffsets, indices, offsets, hashIndices, indiceSizeCumsum,
+        out, tiling, workspace};
     SplitEmbeddingCodegenForwardUnweighted::SplitEmbeddingCodegenForwardUnweightedKernel kernel(args);
     kernel.Compute();
 }
\ No newline at end of file
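The new optional input threaded through this stack is a cumulative count of indices at each table boundary: offset_per_key at the plugin level, indiceSizeCumsum at the kernel entry, offsetPerKey inside the kernel. It is sampled from offsets at stride batchs (see the torch_plugin changes below). A minimal standalone sketch of that derivation with hypothetical numbers; OffsetPerKey is an illustrative helper, not part of the patch:

    #include <cstdint>
    #include <vector>

    // offsets has T * B + 1 entries; sampling every B-th entry yields the
    // cumulative index count at each table boundary (T + 1 values).
    std::vector<int64_t> OffsetPerKey(const std::vector<int64_t>& offsets, int64_t numTables)
    {
        const int64_t batchs = (static_cast<int64_t>(offsets.size()) - 1) / numTables;
        std::vector<int64_t> offsetPerKey;
        for (int64_t t = 0; t <= numTables; ++t) {
            offsetPerKey.push_back(offsets[t * batchs]);
        }
        return offsetPerKey;
    }

    // Example: T = 2 tables, B = 2 batches, offsets = {0, 2, 5, 6, 9}
    // -> offset_per_key = {0, 5, 9}: table 0 owns indices [0, 5), table 1 owns [5, 9).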
diff --git a/mxrec_add_ons/rec_for_torch/operators/split_embedding_codegen_forward_unweighted/op_kernel/split_embedding_codegen_forward_unweighted_kernel.h b/mxrec_add_ons/rec_for_torch/operators/split_embedding_codegen_forward_unweighted/op_kernel/split_embedding_codegen_forward_unweighted_kernel.h
index 359a0e9f8c9fa3360ff7290514b03b2b3f5c75db..627c6d2855e4b0074a7edadb803e7d8670932205 100644
--- a/mxrec_add_ons/rec_for_torch/operators/split_embedding_codegen_forward_unweighted/op_kernel/split_embedding_codegen_forward_unweighted_kernel.h
+++ b/mxrec_add_ons/rec_for_torch/operators/split_embedding_codegen_forward_unweighted/op_kernel/split_embedding_codegen_forward_unweighted_kernel.h
@@ -42,6 +42,7 @@ struct Args {
     GM_ADDR indices;
     GM_ADDR offsets;
     GM_ADDR hashIndices;
+    GM_ADDR offsetPerKey;
     GM_ADDR out;
     GM_ADDR tiling;
     GM_ADDR workspace;
@@ -59,8 +60,8 @@ public:
     __aicore__ inline SplitEmbeddingCodegenForwardUnweightedKernel(Args args)
     {
         GET_TILING_DATA(tilingData, args.tiling);
-        InitAddr(args);
+
         // Shape
         devWeightsDim0 = tilingData.devWeightsDim0;
         weightsOffsetsDim0 = tilingData.weightsOffsetsDim0;
@@ -107,6 +108,7 @@ public:
         offsetGT.SetGlobalBuffer((__gm__ int64_t*)offsets, offsetsDim0);
         dOffsetGT.SetGlobalBuffer((__gm__ int32_t*)dOffsets, dOffsetsDim0);
         weightOffsetGT.SetGlobalBuffer((__gm__ int64_t*)weightsOffsets, weightsOffsetsDim0);
+        offsetPerKeyGT.SetGlobalBuffer((__gm__ int64_t*)offsetPerKey, indicesDim0);
         outGT.SetGlobalBuffer((__gm__ float*)out, outDim0 * outDim1);
@@ -129,6 +131,7 @@ public:
         indices = args.indices;
         offsets = args.offsets;
         hashIndices = args.hashIndices;
+        offsetPerKey = args.offsetPerKey;
         out = args.out;
         workspace = args.workspace;
     }
@@ -176,7 +179,7 @@ public:
         queIn.EnQue(inputLt);
     }

-    __aicore__ inline void CopyOutEC(int64_t thisLen, int64_t startIndices)
+    __aicore__ inline void CopyOutNoPooling(int64_t thisLen, int64_t startIndices)
     {
         LocalTensor<float> inputLt = queIn.DeQue<float>();
         LocalTensor<float> outLt = queOut.AllocTensor<float>();
@@ -193,7 +196,7 @@ public:
         queOut.FreeTensor(outLt);
     }

-    __aicore__ inline void CopyOutECPad(int64_t thisLen, int64_t startIndices)
+    __aicore__ inline void CopyOutNoPoolingPad(int64_t thisLen, int64_t startIndices)
     {
         LocalTensor<float> inputLt = queIn.DeQue<float>();
         LocalTensor<float> outLt = queOut.AllocTensor<float>();
@@ -211,7 +214,7 @@ public:
         queOut.FreeTensor(outLt);
     }

-    __aicore__ inline void CopyOutEBC(int64_t outOffset, int64_t embedDim)
+    __aicore__ inline void CopyOutWithPooling(int64_t outOffset, int64_t embedDim)
     {
         auto outLt = queOut.DeQue<float>();
         SetAtomicAdd<float>();
@@ -237,7 +240,7 @@ public:
         queOut.EnQue(outLt);
     }

-    __aicore__ inline void ProcessEBC(int64_t remain, int64_t startIndices, int64_t embedDim,
+    __aicore__ inline void ProcessWithPooling(int64_t remain, int64_t startIndices, int64_t embedDim,
         int64_t thisWeightOffset, int64_t outOffset)
     {
         float meanLen = static_cast<float>(1) / static_cast<float>(remain);
@@ -253,14 +256,14 @@ public:
             // compute
             Pooling(meanLen, thisLen, embedDim);
             // copyout
-            CopyOutEBC(outOffset, embedDim);
+            CopyOutWithPooling(outOffset, embedDim);

             startIndices = startIndices + thisLen;
             thisLen = remain;
         }
     }

-    __aicore__ inline void ProcessEC(int64_t remain, int64_t startIndices, int64_t thisWeightOffset)
+    __aicore__ inline void ProcessNoPooling(int64_t remain, int64_t startIndices, int64_t thisWeightOffset)
     {
         int64_t thisLen = remain;
         while (remain > 0) {
@@ -268,12 +271,11 @@ public:
                 thisLen = indicesNumOneBlock;
             }
             remain -= thisLen;
-            CopyInNormal(startIndices, thisLen, maxD, thisWeightOffset);
             if (alignMaxD == maxD) {
-                CopyOutEC(thisLen, startIndices);
+                CopyOutNoPooling(thisLen, startIndices);
             } else {
-                CopyOutECPad(thisLen, startIndices);
+                CopyOutNoPoolingPad(thisLen, startIndices);
             }

             startIndices = startIndices + thisLen;
@@ -281,15 +283,42 @@ public:
         }
     }

-    __aicore__ inline void Compute()
+    __aicore__ inline void Scheduler(const int64_t &totalLen, int64_t &offsetLen, int64_t &calcLen)
+    {
+        splitBaseLen = totalLen / GetBlockNum();
+        tailSplitIndex = totalLen % GetBlockNum();
+        if (GetBlockIdx() >= tailSplitIndex) {
+            calcLen = splitBaseLen;
+            offsetLen = tailSplitIndex * (splitBaseLen + 1) + (GetBlockIdx() - tailSplitIndex) * splitBaseLen;
+        } else {
+            calcLen = splitBaseLen + 1;
+            offsetLen = GetBlockIdx() * (splitBaseLen + 1);
+        }
+    }
+
+    __aicore__ inline void ComputeNoPooling()
+    {
+        int64_t lastIndices = 0;
+        int64_t thisTableLen = 0;
+        for (int64_t i = 1; i <= weightsOffsetsDim0; i++) {
+            if (offsetPerKeyGT.GetValue(i) != lastIndices) {
+                Scheduler(offsetPerKeyGT.GetValue(i) - lastIndices, offsetOfThisCore, thisTableLen);
+                if (thisTableLen > 0) {
+                    int64_t thisTableOffset = offsetOfThisCore + lastIndices;
+                    int64_t thisWeightOffset = weightOffsetGT.GetValue(i - 1);
+                    ProcessNoPooling(thisTableLen, thisTableOffset, thisWeightOffset);
+                }
+                lastIndices = offsetPerKeyGT.GetValue(i);
+            }
+        }
+    }
+
+    __aicore__ inline void ComputeWithPooling()
     {
         if (lenOfThisCore == 0) {
             return;
         }
-        indicesNumOneBlock = blockLen / alignMaxD;
-        if (indicesNumOneBlock >= MAX_INDICS_ONE_BLOCK) {
-            indicesNumOneBlock = MAX_INDICS_ONE_BLOCK;
-        }
+
         for (int64_t loop = 0; loop < lenOfThisCore; loop++) {
             int64_t i = (offsetOfThisCore + loop) / weightsOffsetsDim0;
             int64_t j = (offsetOfThisCore + loop) % weightsOffsetsDim0;
@@ -305,17 +334,25 @@ public:
             // dataCopy In params
             int64_t tableIndex = thisOffsetIndex / batchs;
             int64_t thisWeightOffset = weightOffsetGT.GetValue(tableIndex);
+            // dataCopy Out params
+            int64_t outBatchInd = thisOffsetIndex % outDim0;
+            int64_t outEmbedOffset = dOffsetGT.GetValue(tableIndex);
+            int64_t outOffset = outBatchInd * outDim1 + outEmbedOffset;
+            int64_t embedDim = dOffsetGT.GetValue(tableIndex + 1) - dOffsetGT.GetValue(tableIndex);
+            ProcessWithPooling(thisLen, startIndices, embedDim, thisWeightOffset, outOffset);
+        }
+    }

-            if (poolMode == NONE_POOL) {
-                ProcessEC(thisLen, startIndices, thisWeightOffset);
-            } else {
-                // dataCopy Out params
-                int64_t outBatchInd = thisOffsetIndex % outDim0;
-                int64_t outEmbedOffset = dOffsetGT.GetValue(tableIndex);
-                int64_t outOffset = outBatchInd * outDim1 + outEmbedOffset;
-                int64_t embedDim = dOffsetGT.GetValue(tableIndex + 1) - dOffsetGT.GetValue(tableIndex);
-                ProcessEBC(thisLen, startIndices, embedDim, thisWeightOffset, outOffset);
-            }
+    __aicore__ inline void Compute()
+    {
+        indicesNumOneBlock = blockLen / alignMaxD;
+        if (indicesNumOneBlock >= MAX_INDICS_ONE_BLOCK) {
+            indicesNumOneBlock = MAX_INDICS_ONE_BLOCK;
+        }
+        if (poolMode == NONE_POOL) {
+            ComputeNoPooling();
+        } else {
+            ComputeWithPooling();
         }
     }
@@ -328,6 +365,7 @@ private:
     GM_ADDR indices;
     GM_ADDR offsets;
     GM_ADDR hashIndices;
+    GM_ADDR offsetPerKey;
     GM_ADDR out;
     GM_ADDR workspace;
@@ -365,6 +403,10 @@ private:
     int64_t lenOfThisCore;
     int64_t offsetOfThisCore;

+    // dynamic
+    int64_t blockEmbNum;
+    bool isDynamic;
+
     // Tpipe
     TPipe pipe;
     TQue queIn;
     TQue queOut;
@@ -378,6 +420,7 @@
     GlobalTensor<int64_t> offsetGT;
     GlobalTensor<int32_t> dOffsetGT;
     GlobalTensor<int64_t> weightOffsetGT;
+    GlobalTensor<int64_t> offsetPerKeyGT;
 };
 } // namespace SplitEmbeddingCodegenForwardUnweighted
 #endif
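Scheduler follows the usual even-split-with-remainder pattern: each of the GetBlockNum() cores gets totalLen / GetBlockNum() elements, and the first totalLen % GetBlockNum() cores get one extra. ComputeNoPooling then walks the table boundaries stored in offsetPerKeyGT and hands each core its slice of every non-empty table. A host-side simulation of the split, with plain parameters standing in for GetBlockNum()/GetBlockIdx() (assumed stand-ins, for illustration only):

    #include <cassert>
    #include <cstdint>

    // Mirrors Scheduler(): base share per core, remainder spread over the leading cores.
    void Split(int64_t totalLen, int64_t blockNum, int64_t blockIdx,
               int64_t& offsetLen, int64_t& calcLen)
    {
        const int64_t splitBaseLen = totalLen / blockNum;
        const int64_t tailSplitIndex = totalLen % blockNum;
        if (blockIdx >= tailSplitIndex) {
            calcLen = splitBaseLen;
            offsetLen = tailSplitIndex * (splitBaseLen + 1) + (blockIdx - tailSplitIndex) * splitBaseLen;
        } else {
            calcLen = splitBaseLen + 1;
            offsetLen = blockIdx * (splitBaseLen + 1);
        }
    }

    int main()
    {
        // 10 indices across 4 cores -> lengths {3, 3, 2, 2} at offsets {0, 3, 6, 8}:
        // the slices are contiguous and cover the range exactly once.
        int64_t offsetLen = 0;
        int64_t calcLen = 0;
        int64_t covered = 0;
        for (int64_t core = 0; core < 4; ++core) {
            Split(10, 4, core, offsetLen, calcLen);
            assert(offsetLen == covered);
            covered += calcLen;
        }
        assert(covered == 10);
        return 0;
    }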
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/split_embedding_codegen_forward_unweighted/backward_codegen_adagrad_unweighted_exact.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/split_embedding_codegen_forward_unweighted/backward_codegen_adagrad_unweighted_exact.cpp
index 7065b06b915f5a25d9fac11004d403dc693bcd44..0e861d3c07f12ce57b54a33d8befb9bc57e0d6c4 100644
--- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/split_embedding_codegen_forward_unweighted/backward_codegen_adagrad_unweighted_exact.cpp
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/split_embedding_codegen_forward_unweighted/backward_codegen_adagrad_unweighted_exact.cpp
@@ -30,7 +30,7 @@ Tensor split_embedding_backward_codegen_adagrad_unweighted_exact_cuda(
     const int64_t info_B_mask_int64, const bool use_uniq_cache_locations, const bool use_homogeneous_placements,
     Tensor momentum1_dev, Tensor momentum1_uvm, Tensor momentum1_placements, Tensor momentum1_offsets,
     const Tensor& hash_indices, const Tensor& unique_ids, const Tensor& unique_offsets, const Tensor& unique_inverse,
-    double eps = 0, double learning_rate = 0);
+    const Tensor& offset_per_key, double eps = 0, double learning_rate = 0);

 class SplitLookupAdagrad : public torch::autograd::Function<SplitLookupAdagrad> {
 public:
@@ -62,14 +62,20 @@ public:
         auto info_B_num_bits = max_B_;
         auto info_B_mask = T;
-
+
+        // EC lookup: compute the number of indices for each table
+        int64_t batchs = (offsets.numel() - 1) / weights_offsets.numel();
+        at::Tensor table_offsets = torch::arange(D_offsets.size(0), offsets.device()) * batchs;
+        at::Tensor offset_per_key = offsets.index_select(0, table_offsets.to(at::kLong));
+
         ctx->save_for_backward({dev_weights, uvm_weights, lxu_cache_weights, weights_placements, weights_offsets,
             D_offsets, hash_size_cumsum, indices, offsets, indice_weights.value_or(Tensor()),
             feature_requires_grad.value_or(Tensor()), lxu_cache_locations, momentum1_dev, momentum1_uvm,
             momentum1_placements, momentum1_offsets, hash_indices.value_or(Tensor()), unique_ids.value_or(at::Tensor()),
-            unique_offsets.value_or(at::Tensor()), unique_inverse.value_or(at::Tensor())});
+            unique_offsets.value_or(at::Tensor()), unique_inverse.value_or(at::Tensor()),
+            offset_per_key});
         ctx->saved_data["max_D"] = max_D;
         ctx->saved_data["pooling_mode"] = pooling_mode;
         ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits;
@@ -94,7 +100,7 @@ public:
             return {embedding_codegen_forward_op.call(
                 flatten_dev_weights, uvm_weights, lxu_cache_weights, weights_placements, weights_offsets, D_offsets,
                 total_D, max_D, indices, offsets, pooling_mode, lxu_cache_locations, uvm_cache_stats_, output_dtype,
-                is_experimental, hash_indices.value_or(Tensor()))};
+                is_experimental, hash_indices.value_or(Tensor()), offset_per_key)};
         }
         return {at::Tensor()};
     }
@@ -124,6 +130,7 @@ public:
         auto unique_ids = *savedItr++;
         auto unique_offsets = *savedItr++;
         auto unique_inverse = *savedItr++;
+        auto offset_per_key = *savedItr++;
         auto max_D = ctx->saved_data["max_D"].toSymInt();
         auto pooling_mode = ctx->saved_data["pooling_mode"].toInt();
         auto total_hash_size_bits = ctx->saved_data["total_hash_size_bits"].toInt();
@@ -144,7 +151,6 @@ public:
         using torch::autograd::Variable;

         auto grad_output = gradient_clipping ? clamp(grad_outputs[0], -max_gradient, max_gradient) : grad_outputs[0];
-
         static auto embedding_codegen_unweighted_backward_op =
             torch::Dispatcher::singleton()
                 .findSchemaOrThrow("fbgemm::split_embedding_backward_codegen_adagrad_unweighted_exact_cuda", "")
@@ -155,8 +161,8 @@ public:
             max_D, hash_size_cumsum, total_hash_size_bits, indices, offsets, pooling_mode, lxu_cache_locations,
             BT_block_size, max_segment_length_per_warp, stochastic_rounding, info_B_num_bits, info_B_mask_int64,
             use_uniq_cache_locations_bwd, use_homogeneous_placements, momentum1_dev, momentum1_uvm,
-            momentum1_placements, momentum1_offsets, hash_indices, unique_ids, unique_offsets, unique_inverse, eps,
-            learning_rate);
+            momentum1_placements, momentum1_offsets, hash_indices, unique_ids, unique_offsets, unique_inverse,
+            offset_per_key, eps, learning_rate);
         return {
             Tensor(), // placeholder autograd tensor
             Variable(), // output_dtype
@@ -191,6 +197,7 @@ public:
             Variable(), // unique_ids
             Variable(), // unique_offsets
             Variable(), // unique_inverse
+            Variable(), // offset_per_key
             Variable(), // eps
             Variable() // learning_rate
         };
@@ -242,7 +249,7 @@ at::Tensor split_embedding_backward_codegen_adagrad_unweighted_exact_npu(
     const int64_t info_B_mask_int64, const bool use_uniq_cache_locations, const bool use_homogeneous_placements,
     Tensor momentum1_dev, Tensor momentum1_uvm, Tensor momentum1_placements, Tensor momentum1_offsets,
     const Tensor& hash_indices, const at::Tensor& unique_ids, const at::Tensor& unique_offsets,
-    const at::Tensor& unique_inverse, double eps = 0, double learning_rate = 0)
+    const at::Tensor& unique_inverse, const at::Tensor& offset_per_key, double eps = 0, double learning_rate = 0)
 {
     const int64_t t_max_D = max_D.guard_int(__FILE__, __LINE__);
@@ -252,11 +259,6 @@ at::Tensor split_embedding_backward_codegen_adagrad_unweighted_exact_npu(
     int64_t totalEmbed = unique_ids.numel() == 0 ? dev_weights.size(0) : unique_ids.numel() * t_max_D;
     auto output = at::empty({totalEmbed}, dev_weights.options().dtype(at::kFloat));

-    // EC lookup: compute the number of indices for each table
-    int64_t batchs = (offsets.numel() - 1) / weights_offsets.numel();
-    at::Tensor table_offsets = torch::arange(D_offsets.size(0), offsets.device()) * batchs;
-    at::Tensor indice_size_cumsum = offsets.index_select(0, table_offsets.to(at::kLong));
-
     int optim_type = static_cast<int>(OptimizerType::ADAGRAD);
     const auto _unused = Tensor();
     double beta = 0;
@@ -266,7 +268,7 @@ at::Tensor split_embedding_backward_codegen_adagrad_unweighted_exact_npu(
         weights_placements, weights_offsets, D_offsets, hash_size_cumsum, indices, offsets, lxu_cache_locations,
         momentum1_dev, momentum1_uvm, momentum1_placements, momentum1_offsets, _unused, _unused, _unused, _unused,
-        hash_indices, unique_ids, unique_offsets, unique_inverse, indice_size_cumsum, t_max_D,
+        hash_indices, unique_ids, unique_offsets, unique_inverse, offset_per_key, t_max_D,
         total_hash_size_bits, pooling_mode, BT_block_size, max_segment_length_per_warp, stochastic_rounding,
         info_B_num_bits, info_B_mask_int64, use_uniq_cache_locations, use_homogeneous_placements, optim_type,
         eps, learning_rate, beta, beta, iter, output, momentum1_dev, _unused, dev_weights);
@@ -358,6 +360,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m)
         " Tensor unique_ids = None, "
         " Tensor unique_offsets = None, "
         " Tensor unique_inverse = None, "
+        " Tensor offset_per_key = None, "
         " float eps = 0, float learning_rate = 0 "
         ") -> Tensor");
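In each of the three optimizer wrappers (the Adagrad one above, Adam and SGD below), the offset_per_key computation moves out of the NPU backward entry and into the autograd forward, where the result is saved for backward; forward and backward therefore see identical table boundaries, and the tensor is built once per step instead of once per backward call. A sketch of the shared computation, extracted into a hypothetical free function for illustration (the patch inlines it; shapes are as assumed by the surrounding code):

    #include <torch/torch.h>

    // offsets has T * B + 1 entries and D_offsets has T + 1 entries, so batchs
    // recovers B and the index_select picks every B-th offset.
    at::Tensor ComputeOffsetPerKey(const at::Tensor& offsets,
                                   const at::Tensor& weights_offsets,
                                   const at::Tensor& D_offsets)
    {
        int64_t batchs = (offsets.numel() - 1) / weights_offsets.numel();
        at::Tensor table_offsets = torch::arange(D_offsets.size(0), offsets.device()) * batchs;
        return offsets.index_select(0, table_offsets.to(at::kLong));
    }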
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/split_embedding_codegen_forward_unweighted/backward_codegen_adam_unweighted_exact.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/split_embedding_codegen_forward_unweighted/backward_codegen_adam_unweighted_exact.cpp
index 7d9d564c23987bf9a6febeb786ebe6b1f628bf60..cb350e995c74e062a8520a96540ce75cd17b15c2 100644
--- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/split_embedding_codegen_forward_unweighted/backward_codegen_adam_unweighted_exact.cpp
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/split_embedding_codegen_forward_unweighted/backward_codegen_adam_unweighted_exact.cpp
@@ -55,6 +55,7 @@ Tensor split_embedding_backward_codegen_adam_unweighted_exact_cuda(const Tensor&
     const Tensor& unique_ids,
     const Tensor& unique_offsets,
     const Tensor& unique_inverse,
+    const Tensor& offset_per_key,
     double eps = 0,
     double learning_rate = 0,
     double beta1 = 0.9,
@@ -121,13 +122,19 @@ public:
         auto info_B_num_bits = max_B_;
         auto info_B_mask = T;

+        // EC lookup: compute the number of indices for each table
+        int64_t batchs = (offsets.numel() - 1) / weights_offsets.numel();
+        at::Tensor table_offsets = torch::arange(D_offsets.size(0), offsets.device()) * batchs;
+        at::Tensor offset_per_key = offsets.index_select(0, table_offsets.to(at::kLong));
+
         ctx->save_for_backward({dev_weights, uvm_weights, lxu_cache_weights, weights_placements, weights_offsets,
             D_offsets, hash_size_cumsum, indices, offsets, indice_weights.value_or(Tensor()),
             feature_requires_grad.value_or(Tensor()), lxu_cache_locations, momentum1_dev, momentum1_uvm,
             momentum1_placements, momentum1_offsets, momentum2_dev, momentum2_uvm, momentum2_placements,
             momentum2_offsets, hash_indices.value_or(Tensor()), unique_ids.value_or(at::Tensor()),
-            unique_offsets.value_or(at::Tensor()), unique_inverse.value_or(at::Tensor())});
+            unique_offsets.value_or(at::Tensor()), unique_inverse.value_or(at::Tensor()),
+            offset_per_key});
         ctx->saved_data["max_D"] = max_D;
ctx->saved_data["pooling_mode"] = pooling_mode; ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits; @@ -156,7 +163,7 @@ public: return {embedding_codegen_forward_op.call( flatten_dev_weights, uvm_weights, lxu_cache_weights, weights_placements, weights_offsets, D_offsets, total_D, max_D, indices, offsets, pooling_mode, lxu_cache_locations, uvm_cache_stats_, output_dtype, - is_experimental, hash_indices.value_or(Tensor()))}; + is_experimental, hash_indices.value_or(Tensor()), offset_per_key)}; } return {at::Tensor()}; } @@ -190,6 +197,7 @@ public: auto unique_ids = *savedItr++; auto unique_offsets = *savedItr++; auto unique_inverse = *savedItr++; + auto offset_per_key = *savedItr++; auto max_D = ctx->saved_data["max_D"].toSymInt(); auto pooling_mode = ctx->saved_data["pooling_mode"].toInt(); auto total_hash_size_bits = ctx->saved_data["total_hash_size_bits"].toInt(); @@ -225,8 +233,8 @@ public: BT_block_size, max_segment_length_per_warp, stochastic_rounding, info_B_num_bits, info_B_mask_int64, use_uniq_cache_locations_bwd, use_homogeneous_placements, momentum1_dev, momentum1_uvm, momentum1_placements, momentum1_offsets, momentum2_dev, momentum2_uvm, momentum2_placements, - momentum2_offsets, hash_indices, unique_ids, unique_offsets, unique_inverse, eps, learning_rate, beta1, - beta2, iter); + momentum2_offsets, hash_indices, unique_ids, unique_offsets, unique_inverse, offset_per_key, eps, + learning_rate, beta1, beta2, iter); return { Tensor(), // placeholder autograd tensor Variable(), // output_dtype @@ -265,6 +273,7 @@ public: Variable(), // unique_ids Variable(), // unique_offsets Variable(), // unique_inverse + Variable(), // offset_per_key Variable(), // eps Variable(), // learning_rate Variable(), // beta1 @@ -375,6 +384,7 @@ at::Tensor split_embedding_backward_codegen_adam_unweighted_exact_npu(const Tens const at::Tensor& unique_ids, const at::Tensor& unique_offsets, const at::Tensor& unique_inverse, + const at::Tensor& offset_per_key, double eps = 0, double learning_rate = 0, double beta1 = 0, @@ -388,17 +398,12 @@ at::Tensor split_embedding_backward_codegen_adam_unweighted_exact_npu(const Tens int64_t totalEmbed = unique_ids.numel() == 0 ? 
@@ -388,17 +398,12 @@ at::Tensor split_embedding_backward_codegen_adam_unweighted_exact_npu(const Tens
     int64_t totalEmbed = unique_ids.numel() == 0 ? dev_weights.size(0) : unique_ids.numel() * t_max_D;
     auto output = at::empty({totalEmbed}, dev_weights.options().dtype(at::kFloat));

-    // EC lookup: compute the number of indices for each table
-    int64_t batchs = (offsets.numel() - 1) / weights_offsets.numel();
-    at::Tensor table_offsets = torch::arange(D_offsets.size(0), offsets.device()) * batchs;
-    at::Tensor indice_size_cumsum = offsets.index_select(0, table_offsets.to(at::kLong));
-
     int optim_type = static_cast<int>(OptimizerType::ADAM);
     EXEC_NPU_CMD(aclnnBackwardCodegenAdagradUnweightedExact, grad_output, dev_weights, uvm_weights,
         lxu_cache_weights, weights_placements, weights_offsets, D_offsets, hash_size_cumsum, indices, offsets,
         lxu_cache_locations, momentum1_dev, momentum1_uvm, momentum1_placements, momentum1_offsets,
         momentum2_dev, momentum2_uvm, momentum2_placements, momentum2_offsets, hash_indices, unique_ids,
-        unique_offsets, unique_inverse, indice_size_cumsum, t_max_D, total_hash_size_bits, pooling_mode,
+        unique_offsets, unique_inverse, offset_per_key, t_max_D, total_hash_size_bits, pooling_mode,
         BT_block_size, max_segment_length_per_warp, stochastic_rounding, info_B_num_bits, info_B_mask_int64,
         use_uniq_cache_locations, use_homogeneous_placements, optim_type, eps, learning_rate, beta1, beta2,
         iter, output, momentum1_dev, momentum2_dev, dev_weights);
@@ -495,6 +500,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m)
         " Tensor unique_ids = None, "
         " Tensor unique_offsets = None, "
         " Tensor unique_inverse = None, "
+        " Tensor offset_per_key = None, "
         " float eps = 0, float learning_rate = 0, float beta1 = 0, float beta2 = 0, int iter = 0 "
         ") -> Tensor");
     m.impl("split_embedding_backward_codegen_adam_unweighted_exact_cuda",
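On the consuming side, adjacent entries of offset_per_key differ by exactly one table's index count, which is the per-table quantity the patch's own comment ("compute the number of indices for each table") refers to. A minimal sketch of that differencing; IndicesPerTable is an illustrative helper, not an API in the patch:

    #include <cstdint>
    #include <vector>

    // Adjacent differences of offset_per_key, e.g. {0, 5, 9} -> {5, 4}.
    std::vector<int64_t> IndicesPerTable(const std::vector<int64_t>& offsetPerKey)
    {
        std::vector<int64_t> counts;
        for (size_t i = 1; i < offsetPerKey.size(); ++i) {
            counts.push_back(offsetPerKey[i] - offsetPerKey[i - 1]);
        }
        return counts;
    }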
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/split_embedding_codegen_forward_unweighted/backward_codegen_sgd_unweighted_exact.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/split_embedding_codegen_forward_unweighted/backward_codegen_sgd_unweighted_exact.cpp
index 8366351049ee9e1a1b190d29485c11ca40b4b97e..8f2dd983b5decf4327a04bfd158cee9aa3d6f2d4 100644
--- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/split_embedding_codegen_forward_unweighted/backward_codegen_sgd_unweighted_exact.cpp
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/split_embedding_codegen_forward_unweighted/backward_codegen_sgd_unweighted_exact.cpp
@@ -47,6 +47,7 @@ Tensor split_embedding_backward_codegen_sgd_unweighted_exact_cuda(const Tensor&
     const Tensor& unique_ids,
     const Tensor& unique_offsets,
     const Tensor& unique_inverse,
+    const Tensor& offset_per_key,
     double learning_rate = 0);

 class SplitLookupSGD : public torch::autograd::Function<SplitLookupSGD> {
@@ -97,11 +98,17 @@ public:
         auto info_B_num_bits = max_B_;
         auto info_B_mask = T;

+        // EC lookup: compute the number of indices for each table
+        int64_t batchs = (offsets.numel() - 1) / weights_offsets.numel();
+        at::Tensor table_offsets = torch::arange(D_offsets.size(0), offsets.device()) * batchs;
+        at::Tensor offset_per_key = offsets.index_select(0, table_offsets.to(at::kLong));
+
         ctx->save_for_backward({dev_weights, uvm_weights, lxu_cache_weights, weights_placements, weights_offsets,
             D_offsets, hash_size_cumsum, indices, offsets, indice_weights.value_or(Tensor()),
             feature_requires_grad.value_or(Tensor()), lxu_cache_locations,
             hash_indices.value_or(Tensor()), unique_ids.value_or(at::Tensor()),
-            unique_offsets.value_or(at::Tensor()), unique_inverse.value_or(at::Tensor())});
+            unique_offsets.value_or(at::Tensor()), unique_inverse.value_or(at::Tensor()),
+            offset_per_key});
         ctx->saved_data["max_D"] = max_D;
         ctx->saved_data["pooling_mode"] = pooling_mode;
         ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits;
@@ -125,7 +132,7 @@ public:
             return {embedding_codegen_forward_op.call(
                 flatten_dev_weights, uvm_weights, lxu_cache_weights, weights_placements, weights_offsets, D_offsets,
                 total_D, max_D, indices, offsets, pooling_mode, lxu_cache_locations, uvm_cache_stats_, output_dtype,
-                is_experimental, hash_indices.value_or(Tensor()))};
+                is_experimental, hash_indices.value_or(Tensor()), offset_per_key)};
         }
         return {at::Tensor()};
     }
@@ -151,6 +158,7 @@ public:
         auto unique_ids = *savedItr++;
         auto unique_offsets = *savedItr++;
         auto unique_inverse = *savedItr++;
+        auto offset_per_key = *savedItr++;
         auto max_D = ctx->saved_data["max_D"].toSymInt();
         auto pooling_mode = ctx->saved_data["pooling_mode"].toInt();
         auto total_hash_size_bits = ctx->saved_data["total_hash_size_bits"].toInt();
@@ -181,7 +189,7 @@ public:
             max_D, hash_size_cumsum, total_hash_size_bits, indices, offsets, pooling_mode, lxu_cache_locations,
             BT_block_size, max_segment_length_per_warp, stochastic_rounding, info_B_num_bits, info_B_mask_int64,
             use_uniq_cache_locations_bwd, use_homogeneous_placements, hash_indices, unique_ids, unique_offsets,
-            unique_inverse, learning_rate);
+            unique_inverse, offset_per_key, learning_rate);
         return {
             Tensor(), // placeholder autograd tensor
             Variable(), // output_dtype
@@ -296,6 +304,7 @@ at::Tensor split_embedding_backward_codegen_sgd_unweighted_exact_npu(const Tenso
     const at::Tensor& unique_ids,
     const at::Tensor& unique_offsets,
     const at::Tensor& unique_inverse,
+    const at::Tensor& offset_per_key,
     double learning_rate = 0)
 {
     const int64_t t_max_D = max_D.guard_int(__FILE__, __LINE__);
@@ -312,7 +321,7 @@ at::Tensor split_embedding_backward_codegen_sgd_unweighted_exact_npu(const Tenso
         aclnnBackwardCodegenAdagradUnweightedExact, grad_output, dev_weights, uvm_weights, lxu_cache_weights,
         weights_placements, weights_offsets, D_offsets, hash_size_cumsum, indices, offsets, lxu_cache_locations,
         _unused, _unused, _unused, _unused, _unused, _unused, _unused, _unused, hash_indices, unique_ids,
-        unique_offsets, unique_inverse, _unused, t_max_D, total_hash_size_bits, pooling_mode,
+        unique_offsets, unique_inverse, offset_per_key, t_max_D, total_hash_size_bits, pooling_mode,
         BT_block_size, max_segment_length_per_warp, stochastic_rounding, info_B_num_bits, info_B_mask_int64,
         use_uniq_cache_locations, use_homogeneous_placements, optim_type, beta, learning_rate, beta, beta, iter,
         output, _unused, _unused, dev_weights);
@@ -403,6 +412,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m)
         " Tensor unique_ids = None, "
         " Tensor unique_offsets = None, "
         " Tensor unique_inverse = None, "
+        " Tensor offset_per_key = None, "
         " float learning_rate = 0 "
         ") -> Tensor");
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/split_embedding_codegen_forward_unweighted/split_embedding_codegen_forward_unweighted.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/split_embedding_codegen_forward_unweighted/split_embedding_codegen_forward_unweighted.cpp
index 372ba2ed46335f8d3fef4fe143d5f996ad354348..4fac367ce7da15d63309a945479df79dfa0caa66 100644
--- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/split_embedding_codegen_forward_unweighted/split_embedding_codegen_forward_unweighted.cpp
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/split_embedding_codegen_forward_unweighted/split_embedding_codegen_forward_unweighted.cpp
@@ -36,7 +36,8 @@ at::Tensor split_embedding_codegen_forward_unweighted_npu(const at::Tensor& dev_
     const at::Tensor& uvm_cache_stats,
     const int64_t output_dtype,
     const bool is_experimental,
-    const Tensor& hash_indices)
+    const Tensor& hash_indices,
+    const at::Tensor& offset_per_key)
 {
     const int64_t totalD = total_D.guard_int(__FILE__, __LINE__);
     const int64_t maxD = max_D.guard_int(__FILE__, __LINE__);
@@ -65,7 +66,7 @@ at::Tensor split_embedding_codegen_forward_unweighted_npu(const at::Tensor& dev_
     int64_t experimental = static_cast<int64_t>(is_experimental);
     EXEC_NPU_CMD(aclnnSplitEmbeddingCodegenForwardUnweighted, dev_weights, uvm_weights, lxu_cache_weights,
         weights_placements, weights_offsets, D_offsets, indices, offsets, lxu_cache_locations, hash_indices,
-        totalD, maxD, pooling_mode, output_dtype, experimental, output);
+        offset_per_key, totalD, maxD, pooling_mode, output_dtype, experimental, output);
     return output;
 }
@@ -89,7 +90,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m)
         " Tensor uvm_cache_stats, "
         " int output_dtype, "
         " bool is_experimental, "
-        " Tensor hash_indices = None "
+        " Tensor hash_indices = None, "
+        " Tensor offset_per_key = None "
         ") -> Tensor");
     m.impl("split_embedding_codegen_forward_unweighted_cuda",
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/split_embedding_codegen_forward_unweighted/split_embedding_codegen_forward_unweighted.h b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/split_embedding_codegen_forward_unweighted/split_embedding_codegen_forward_unweighted.h
index fcc15ef0438ab2cb2f7d87139a4f2dae5c4b38c0..6180bf5b71b4a9dbfb6a5298a7b2d1f7d5a5e87e 100644
--- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/split_embedding_codegen_forward_unweighted/split_embedding_codegen_forward_unweighted.h
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/split_embedding_codegen_forward_unweighted/split_embedding_codegen_forward_unweighted.h
@@ -54,7 +54,8 @@ at::Tensor split_embedding_codegen_forward_unweighted_cuda(const at::Tensor& dev
     const at::Tensor& uvm_cache_stats,
     const int64_t output_dtype,
     const bool is_experimental,
-    const at::Tensor& hash_indices);
+    const at::Tensor& hash_indices,
+    const at::Tensor& offset_per_key);

 at::Tensor split_embedding_codegen_forward_unweighted_npu(const at::Tensor& dev_weights,
     const at::Tensor& uvm_weights,
@@ -71,6 +72,7 @@ at::Tensor split_embedding_codegen_forward_unweighted_npu(const at::Tensor& dev_
     const at::Tensor& uvm_cache_stats,
     const int64_t output_dtype,
     const bool is_experimental,
-    const at::Tensor& hash_indices);
+    const at::Tensor& hash_indices,
+    const at::Tensor& offset_per_key);
 }; // namespace fbgemm_npu_lookups
 #endif // MXREC_ADD_ONS_SPLIT_EMBEDDING_CODEGEN_FORWARD_UNWEIGHTED_H