diff --git a/cmake/pten_kernel.cmake b/cmake/pten_kernel.cmake index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..f962c1332093abc2b19e87a844cc6614acc9fac9 100644 --- a/cmake/pten_kernel.cmake +++ b/cmake/pten_kernel.cmake @@ -0,0 +1,186 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# call kernel_declare need to make sure whether the target of input exists +function(kernel_declare TARGET_LIST) + foreach(kernel_path ${TARGET_LIST}) + file(READ ${kernel_path} kernel_impl) + # TODO(chenweihang): rename PT_REGISTER_CTX_KERNEL to PT_REGISTER_KERNEL + # NOTE(chenweihang): now we don't recommend to use digit in kernel name + string(REGEX MATCH "(PT_REGISTER_CTX_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*," first_registry "${kernel_impl}") + if (NOT first_registry STREQUAL "") + # parse the first kernel name + string(REPLACE "PT_REGISTER_CTX_KERNEL(" "" kernel_name "${first_registry}") + string(REPLACE "PT_REGISTER_GENERAL_KERNEL(" "" kernel_name "${kernel_name}") + string(REPLACE "," "" kernel_name "${kernel_name}") + string(REGEX REPLACE "[ \t\r\n]+" "" kernel_name "${kernel_name}") + # append kernel declare into declarations.h + # TODO(chenweihang): default declare ALL_LAYOUT for each kernel + if (${kernel_path} MATCHES "./cpu\/") + file(APPEND ${kernel_declare_file} "PT_DECLARE_KERNEL(${kernel_name}, CPU, ALL_LAYOUT);\n") + elseif (${kernel_path} MATCHES "./gpu\/") + file(APPEND ${kernel_declare_file} "PT_DECLARE_KERNEL(${kernel_name}, GPU, ALL_LAYOUT);\n") + elseif (${kernel_path} MATCHES "./xpu\/") + file(APPEND ${kernel_declare_file} "PT_DECLARE_KERNEL(${kernel_name}, XPU, ALL_LAYOUT);\n") + else () + # deal with device independent kernel, now we use CPU temporaary + file(APPEND ${kernel_declare_file} "PT_DECLARE_KERNEL(${kernel_name}, CPU, ALL_LAYOUT);\n") + endif() + endif() + endforeach() +endfunction() + +function(kernel_library TARGET) + set(common_srcs) + set(cpu_srcs) + set(gpu_srcs) + set(xpu_srcs) + # parse and save the deps kerenl targets + set(all_srcs) + set(kernel_deps) + + set(oneValueArgs "") + set(multiValueArgs SRCS DEPS) + cmake_parse_arguments(kernel_library "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN}) + + list(LENGTH kernel_library_SRCS kernel_library_SRCS_len) + # one kernel only match one impl file in each backend + if (${kernel_library_SRCS_len} EQUAL 0) + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc) + list(APPEND common_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc) + endif() + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/cpu/${TARGET}.cc) + list(APPEND cpu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/cpu/${TARGET}.cc) + endif() + if (WITH_GPU OR WITH_ROCM) + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/gpu/${TARGET}.cu) + list(APPEND gpu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/gpu/${TARGET}.cu) + endif() + endif() + if (WITH_XPU) + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/xpu/${TARGET}.cc) + list(APPEND xpu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/xpu/${TARGET}.cc) + endif() + endif() + else() + # 
TODO(chenweihang): impl compile by source later + endif() + + list(APPEND all_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.h) + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/impl/${TARGET}_impl.h) + list(APPEND all_srcs ${CMAKE_CURRENT_SOURCE_DIR}/impl/${TARGET}_impl.h) + endif() + list(APPEND all_srcs ${common_srcs}) + list(APPEND all_srcs ${cpu_srcs}) + list(APPEND all_srcs ${gpu_srcs}) + list(APPEND all_srcs ${xpu_srcs}) + foreach(src ${all_srcs}) + file(READ ${src} target_content) + string(REGEX MATCHALL "#include \"paddle\/pten\/kernels\/[a-z0-9_]+_kernel.h\"" include_kernels ${target_content}) + foreach(include_kernel ${include_kernels}) + string(REGEX REPLACE "#include \"paddle\/pten\/kernels\/" "" kernel_name ${include_kernel}) + string(REGEX REPLACE ".h\"" "" kernel_name ${kernel_name}) + list(APPEND kernel_deps ${kernel_name}) + endforeach() + endforeach() + list(REMOVE_DUPLICATES kernel_deps) + list(REMOVE_ITEM kernel_deps ${TARGET}) + + list(LENGTH common_srcs common_srcs_len) + list(LENGTH cpu_srcs cpu_srcs_len) + list(LENGTH gpu_srcs gpu_srcs_len) + list(LENGTH xpu_srcs xpu_srcs_len) + + if (${common_srcs_len} GREATER 0) + # If the kernel has a device independent public implementation, + # we will use this implementation and will not adopt the implementation + # under specific devices + if (WITH_GPU) + nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + elseif (WITH_ROCM) + hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + else() + cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + endif() + else() + # If the kernel has a header file declaration, but no corresponding + # implementation can be found, this is not allowed + if (${cpu_srcs_len} EQUAL 0 AND ${gpu_srcs_len} EQUAL 0 AND + ${xpu_srcs_len} EQUAL 0) + message(FATAL_ERROR "Cannot find any implementation for ${TARGET}") + else() + if (WITH_GPU) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) + nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + endif() + elseif (WITH_ROCM) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) + hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + endif() + else() + if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0) + cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + endif() + endif() + endif() + endif() + + if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR + ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0) + # append target into PTEN_KERNELS property + get_property(pten_kernels GLOBAL PROPERTY PTEN_KERNELS) + set(pten_kernels ${pten_kernels} ${TARGET}) + set_property(GLOBAL PROPERTY PTEN_KERNELS ${pten_kernels}) + endif() + + # parse kernel name and auto generate kernel declaration + # here, we don't need to check WITH_XXX, because if not WITH_XXX, the + # xxx_srcs_len will be equal to 0 + if (${common_srcs_len} GREATER 0) + kernel_declare(${common_srcs}) + endif() + if (${cpu_srcs_len} GREATER 0) + kernel_declare(${cpu_srcs}) + endif() + if (${gpu_srcs_len} GREATER 0) + kernel_declare(${gpu_srcs}) + endif() + if (${xpu_srcs_len} GREATER 0) + kernel_declare(${xpu_srcs}) + endif() +endfunction() + +function(register_kernels) + set(options "") + set(oneValueArgs "") + set(multiValueArgs EXCLUDES DEPS) + cmake_parse_arguments(register_kernels "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN}) + + 
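  # Illustrative usage sketch (the kernel and dependency names below are
  # placeholders, not targets defined by this patch): a kernels-level
  # CMakeLists.txt would typically call
  #   register_kernels(EXCLUDES some_special_kernel DEPS some_common_dep)
  #   kernel_library(some_special_kernel DEPS ${SOME_SPECIAL_DEPS})
  # register_kernels() globs every *_kernel.h in the current source
  # directory and creates one kernel_library() target per header (minus the
  # .h suffix), skipping any name passed via EXCLUDES.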
file(GLOB KERNELS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_kernel.h") + string(REPLACE ".h" "" KERNELS "${KERNELS}") + list(LENGTH register_kernels_DEPS register_kernels_DEPS_len) + + foreach(target ${KERNELS}) + list(FIND register_kernels_EXCLUDES ${target} _index) + if (${_index} EQUAL -1) + if (${register_kernels_DEPS_len} GREATER 0) + kernel_library(${target} DEPS ${register_kernels_DEPS}) + else() + kernel_library(${target}) + endif() + endif() + endforeach() +endfunction() diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index c3e54290fd3da01ab7488f9828afff5e328f5c8e..dc4d1365093aa35293f8a9320fd164b4c29f6660 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1880,16 +1880,32 @@ void OperatorWithKernel::BuildPtenKernelContext( // Otherwise,we will create new storage. for (size_t offset = 0; offset < outs_vector.size(); ++offset) { if (current_vector_size > start_idx + offset) { - experimental::ReMakePtenDenseTensorFromVar( - outs_vector[offset], out_def, + auto* buffer_tensor = pt_kernel_context_->MutableOutputAt(start_idx + - offset)); + offset); + if (buffer_tensor) { + experimental::ReMakePtenDenseTensorFromVar(outs_vector[offset], + out_def, buffer_tensor); + } } else { pt_kernel_context_->EmplaceBackOutputWithoutSetRange( experimental::MakePtenTensorBaseFromVar(outs_vector[offset], out_def)); } } + + // Deal with the case that some outputs are NULL when run the kernel. + // For example : the outputs of matmul_grad are dx and dy, + // sometimes dx or dy may be NULL. + if (outs_vector.empty()) { + if (current_vector_size > start_idx) { + pt_kernel_context_->SetOutputWithoutSetRange(start_idx, {nullptr}); + } else { + pt_kernel_context_->EmplaceBackOutputWithoutSetRange({nullptr}); + } + end_idx = start_idx + 1; + } + pt_kernel_context_->AssignOutputRange(std::make_pair(start_idx, end_idx), i); } @@ -2002,7 +2018,9 @@ void OperatorWithKernel::WriteBackToOutputs(RuntimeContext* ctx) const { range_pair.first, range_pair.second); for (size_t j = 0; j < pten_outs.size(); ++j) { - experimental::MakeVariableFromPtenTensor(pten_outs[j], outs_vector[j]); + if (pten_outs[j]) { + experimental::MakeVariableFromPtenTensor(pten_outs[j], outs_vector[j]); + } } } } diff --git a/paddle/fluid/framework/pten_utils.cc b/paddle/fluid/framework/pten_utils.cc index 9831c2628dc95988fc4cf24bfd6c665d283f8cbf..dddcd914ed28ae7ae3ecb43bdf5ffa32ebb6f8c5 100644 --- a/paddle/fluid/framework/pten_utils.cc +++ b/paddle/fluid/framework/pten_utils.cc @@ -99,7 +99,7 @@ KernelSignatureMap& KernelSignatureMap::Instance() { const auto& op_type = pair.first; const auto* op_proto = pair.second.proto_; if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type) && - op_proto != nullptr) { + op_proto) { KernelArgsNameMakerByOpProto maker(op_proto); VLOG(10) << "Register kernel signature for " << op_type; auto success = kernel_signature_map_->map_ diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index c355ace528d42cf1db6c3e1df6030deae78e1c2a..1d12ecf30ede50a44fed5de4f285e466acecbe04 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -338,19 +338,41 @@ static void BuildDygraphPtenKernelContext( for (size_t i = 0; i < output_names.size(); ++i) { auto& out_def = output_defs.at(i); - auto& outs_vector = outs.at(output_names[i]); size_t start_idx = (i == 0 ? 
0 : kernel_ctx->OutputRangeAt(i - 1).second); - size_t end_idx = start_idx + outs_vector.size(); auto current_vector_size = kernel_ctx->OutputsSize(); + + auto iter = outs.find(output_names[i]); + if (iter == outs.end()) { + if (current_vector_size > start_idx) { + kernel_ctx->SetOutputWithoutSetRange(start_idx, {nullptr}); + } else { + kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr}); + } + kernel_ctx->AssignOutputRange(std::make_pair(start_idx, start_idx + 1), + i); + continue; + } + + auto& outs_vector = iter->second; + size_t end_idx = start_idx + outs_vector.size(); + // If the memory needed is less than the current memory allocated, we will // reuse the current memory by using ReMakePtenDenseTensorFromVar. // Otherwise,we will create new storage. for (size_t offset = 0; offset < outs_vector.size(); ++offset) { if (current_vector_size > start_idx + offset) { - experimental::ReMakePtenDenseTensorFromVar( - outs_vector[offset]->MutableVar(), out_def, - kernel_ctx->MutableOutputAt(start_idx + offset)); + auto* buffer_tensor = + kernel_ctx->MutableOutputAt(start_idx + offset); + if (buffer_tensor) { + experimental::ReMakePtenDenseTensorFromVar( + outs_vector[offset]->MutableVar(), out_def, buffer_tensor); + } else { + kernel_ctx->SetOutputWithoutSetRange( + start_idx + offset, + experimental::MakePtenTensorBaseFromVar( + outs_vector[offset]->MutableVar(), out_def)); + } } else { kernel_ctx->EmplaceBackOutputWithoutSetRange( experimental::MakePtenTensorBaseFromVar( @@ -465,15 +487,18 @@ static void WriteBackToOutputs( auto& output_names = std::get<2>(pt_kernel_signature.args); for (size_t i = 0; i < output_names.size(); ++i) { - auto& outs_vector = outs.at(output_names[i]); + auto iter = outs.find(output_names[i]); + if (iter != outs.end()) { + auto& outs_vector = iter->second; - auto& range_pair = kernel_ctx->OutputRangeAt(i); - auto pten_outs = kernel_ctx->MutableOutputBetween( - range_pair.first, range_pair.second); + auto& range_pair = kernel_ctx->OutputRangeAt(i); + auto pten_outs = kernel_ctx->MutableOutputBetween( + range_pair.first, range_pair.second); - for (size_t j = 0; j < pten_outs.size(); ++j) { - experimental::MakeVariableFromPtenTensor(pten_outs[j], - outs_vector[j]->MutableVar()); + for (size_t j = 0; j < pten_outs.size(); ++j) { + experimental::MakeVariableFromPtenTensor(pten_outs[j], + outs_vector[j]->MutableVar()); + } } } } @@ -529,6 +554,7 @@ static void PreparedOpRunImpl( template static void PreparedOpRunPtImpl( const framework::OperatorBase& op, + const framework::OpKernelType& kernel_type, const framework::KernelSignature& pt_kernel_signature, const pten::Kernel& pt_kernel, pten::KernelContext* pt_kernel_context, platform::DeviceContext* dev_ctx, const NameVarMap& ins, @@ -558,7 +584,9 @@ static void PreparedOpRunPtImpl( pt_kernel_context->ClearData(); // TODO(chenweihang): add debug flags later - // TODO(chenweihang): deal with complex cases later + if (framework::IsComplexType(kernel_type.data_type_)) { + HandleComplexGradToRealGrad(outs); + } } void PreparedOp::Run(const NameVarMap& ins, @@ -566,9 +594,9 @@ void PreparedOp::Run(const NameVarMap& ins, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs) { if (run_pten_kernel_) { - PreparedOpRunPtImpl(op_, pt_kernel_signature_, pt_kernel_, - pt_kernel_context_, dev_ctx_, ins, outs, attrs, - default_attrs); + PreparedOpRunPtImpl(op_, kernel_type_, pt_kernel_signature_, + pt_kernel_, pt_kernel_context_, dev_ctx_, ins, + outs, attrs, default_attrs); } else { 
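    // When no pten kernel was matched (run_pten_kernel_ is false), fall back
    // to the original fluid OpKernel execution path below.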
PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, outs, attrs, default_attrs); @@ -580,9 +608,9 @@ void PreparedOp::Run(const NameVarMap& ins, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs) { if (run_pten_kernel_) { - PreparedOpRunPtImpl(op_, pt_kernel_signature_, pt_kernel_, - pt_kernel_context_, dev_ctx_, ins, - outs, attrs, default_attrs); + PreparedOpRunPtImpl( + op_, kernel_type_, pt_kernel_signature_, pt_kernel_, pt_kernel_context_, + dev_ctx_, ins, outs, attrs, default_attrs); } else { PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, outs, attrs, default_attrs); diff --git a/paddle/fluid/operators/conj_op.h b/paddle/fluid/operators/conj_op.h index 417a136c60b618d4418d072f31d12d6d2e175027..381f4cb66b3cd6d511bcd95d7cba5842023b501e 100644 --- a/paddle/fluid/operators/conj_op.h +++ b/paddle/fluid/operators/conj_op.h @@ -14,11 +14,13 @@ #pragma once -#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/math/complex_functors.h" -#include "paddle/fluid/platform/for_range.h" + +// only can include the headers in paddle/pten/api dirs +#include "paddle/pten/api/lib/utils/tensor_utils.h" +#include "paddle/pten/include/core.h" +#include "paddle/pten/kernels/complex_kernel.h" namespace paddle { namespace operators { @@ -30,16 +32,14 @@ class ConjKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { const Tensor* x = context.Input("X"); Tensor* out = context.Output("Out"); + out->mutable_data(context.GetPlace(), size_t(x->numel() * sizeof(T))); - auto numel = x->numel(); - auto* x_data = x->data(); - auto* out_data = out->mutable_data(context.GetPlace(), - size_t(x->numel() * sizeof(T))); + auto& dev_ctx = context.device_context(); + auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); + auto pt_out = paddle::experimental::MakePtenDenseTensor(*out); - auto& dev_ctx = context.template device_context(); - platform::ForRange for_range(dev_ctx, numel); - math::ConjFunctor functor(x_data, numel, out_data); - for_range(functor); + // call new kernel + pten::ConjKernel(dev_ctx, *pt_x.get(), pt_out.get()); } }; diff --git a/paddle/fluid/operators/dot_op.cc b/paddle/fluid/operators/dot_op.cc index 31acd9718115c78568326532e922aad543164732..e1463c8ccb58ebd4fe65bdcd2fb2a5e30def74e7 100644 --- a/paddle/fluid/operators/dot_op.cc +++ b/paddle/fluid/operators/dot_op.cc @@ -117,6 +117,13 @@ class DotGradOp : public framework::OperatorWithKernel { ctx, framework::GradVarName("Out")), ctx.GetPlace()); } + + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext& ctx) const override { + return framework::KernelSignature( + "dot_grad", {"X", "Y", framework::GradVarName("Out")}, {}, + {framework::GradVarName("X"), framework::GradVarName("Y")}); + } }; template diff --git a/paddle/fluid/operators/dot_op.h b/paddle/fluid/operators/dot_op.h index 09d607891b48542876a374cbf00db713befde4b2..02ba57ef8d495ff1e0b7c608bc3c3f55dd8c7163 100644 --- a/paddle/fluid/operators/dot_op.h +++ b/paddle/fluid/operators/dot_op.h @@ -19,257 +19,34 @@ #include "paddle/fluid/operators/math/complex_functors.h" #include "paddle/fluid/platform/for_range.h" +// only can include the headers in paddle/pten/api dirs +#include "paddle/pten/api/lib/utils/tensor_utils.h" +#include "paddle/pten/include/core.h" +#include "paddle/pten/kernels/dot_grad_kernel.h" +#include 
"paddle/pten/kernels/dot_kernel.h" + namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -struct P { - void operator()(T a, R b); -}; - -template -struct DotGradFunction { - void operator()(const Tensor* tensor_x, const Tensor* tensor_y, - const Tensor* tensor_dout, Tensor* tensor_dx, - Tensor* tensor_dy, - const paddle::framework::ExecutionContext& ctx); -}; - -template -struct DotGradFunction> { - void operator()(const Tensor* tensor_x, const Tensor* tensor_y, - const Tensor* tensor_dout, Tensor* tensor_dx, - Tensor* tensor_dy, - const paddle::framework::ExecutionContext& ctx) { -#if defined(__NVCC__) || defined(__HIPCC__) - if (1 == tensor_dout->dims().size()) { - auto dout = framework::EigenVector::Flatten(*tensor_dout); - - if (tensor_dx) { - auto y = framework::EigenVector::Flatten(*tensor_y); - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - Eigen::DSizes size(tensor_dx->numel()); - - paddle::platform::ForRange for_range(dev_raw, - tensor_y->numel()); - math::ConjFunctor functor(tensor_y->data(), tensor_y->numel(), - tensor_dx->data()); - for_range(functor); - auto dx = framework::EigenVector::Flatten(*tensor_dx); - - dx.device(dev) = dx * dout.broadcast(size); - } - - if (tensor_dy) { - auto x = framework::EigenVector::Flatten(*tensor_x); - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - Eigen::DSizes size(tensor_dy->numel()); - - paddle::platform::ForRange for_range(dev_raw, - tensor_y->numel()); - math::ConjFunctor functor(tensor_x->data(), tensor_x->numel(), - tensor_dy->data()); - for_range(functor); - auto dy = framework::EigenVector::Flatten(*tensor_dy); - - dy.device(dev) = dy * dout.broadcast(size); - } - } else { - auto dout = framework::EigenMatrix::From(*tensor_dout); - - if (tensor_dx) { - tensor_dx->mutable_data(ctx.GetPlace()); - auto y = framework::EigenMatrix::From(*tensor_y); - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - Eigen::DSizes size(1, tensor_dx->dims()[1]); - - paddle::platform::ForRange for_range(dev_raw, - tensor_y->numel()); - math::ConjFunctor functor(tensor_y->data(), tensor_y->numel(), - tensor_dx->data()); - for_range(functor); - auto dx = framework::EigenMatrix::From(*tensor_dx); - - dx.device(dev) = dx * dout.broadcast(size); - } - - if (tensor_dy) { - tensor_dy->mutable_data(ctx.GetPlace()); - auto x = framework::EigenMatrix::From(*tensor_x); - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - Eigen::DSizes size(1, tensor_dy->dims()[1]); - - paddle::platform::ForRange for_range(dev_raw, - tensor_x->numel()); - math::ConjFunctor functor(tensor_x->data(), tensor_x->numel(), - tensor_dy->data()); - for_range(functor); - - auto dy = framework::EigenMatrix::From(*tensor_dy); - - dy.device(dev) = dy * dout.broadcast(size); - } - } -#else - const auto* data_dout = tensor_dout->data(); - - if (tensor_dx) { - auto* data_dx = tensor_dx->mutable_data(ctx.GetPlace()); - const auto* data_y = tensor_y->data(); - const framework::DDim& dim = tensor_x->dims(); - size_t N = static_cast(framework::product(dim)); - - auto step = dim[dim.size() - 1]; - - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dx[i] = T(data_y[i].real, -data_y[i].imag) * data_dout[s]; - } - } - - if (tensor_dy) { - auto* data_dy = tensor_dy->mutable_data(ctx.GetPlace()); - const auto* data_x = tensor_x->data(); - const framework::DDim& dim = tensor_y->dims(); - size_t N 
= static_cast(framework::product(dim)); - - auto step = dim[dim.size() - 1]; - - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dy[i] = T(data_x[i].real, -data_x[i].imag) * data_dout[s]; - } - } -#endif - } -}; - -template -struct DotGradFunction> { - void operator()(const Tensor* tensor_x, const Tensor* tensor_y, - const Tensor* tensor_dout, Tensor* tensor_dx, - Tensor* tensor_dy, - const paddle::framework::ExecutionContext& ctx) { -#if defined(__NVCC__) || defined(__HIPCC__) - if (1 == tensor_dout->dims().size()) { - auto dout = framework::EigenVector::Flatten(*tensor_dout); - - if (tensor_dx) { - auto y = framework::EigenVector::Flatten(*tensor_y); - auto dx = framework::EigenVector::Flatten(*tensor_dx); - auto& dev = - *ctx.template device_context().eigen_device(); - Eigen::DSizes size(tensor_dx->numel()); - dx.device(dev) = y * dout.broadcast(size); - } - - if (tensor_dy) { - auto x = framework::EigenVector::Flatten(*tensor_x); - auto dy = framework::EigenVector::Flatten(*tensor_dy); - auto& dev = - *ctx.template device_context().eigen_device(); - Eigen::DSizes size(tensor_dy->numel()); - dy.device(dev) = x * dout.broadcast(size); - } - } else { - auto dout = framework::EigenMatrix::From(*tensor_dout); - - if (tensor_dx) { - tensor_dx->mutable_data(ctx.GetPlace()); - auto y = framework::EigenMatrix::From(*tensor_y); - auto dx = framework::EigenMatrix::From(*tensor_dx); - auto& dev = - *ctx.template device_context().eigen_device(); - Eigen::DSizes size(1, tensor_dx->dims()[1]); - dx.device(dev) = y * dout.broadcast(size); - } - - if (tensor_dy) { - tensor_dy->mutable_data(ctx.GetPlace()); - auto x = framework::EigenMatrix::From(*tensor_x); - auto dy = framework::EigenMatrix::From(*tensor_dy); - auto& dev = - *ctx.template device_context().eigen_device(); - Eigen::DSizes size(1, tensor_dy->dims()[1]); - dy.device(dev) = x * dout.broadcast(size); - } - } -#else - auto const *x = tensor_x->data(), *y = tensor_y->data(), - *dz = tensor_dout->data(); - auto&& d = tensor_x->dims(); - auto const N = tensor_x->numel(); - auto const B = d[d.size() - 1]; - - if (tensor_dx) { - auto* dx = tensor_dx->mutable_data(ctx.GetPlace()); - for (auto j = 0; j < N / B; ++j) { - auto const ss = dz[j]; - for (auto i = 0; i < B; ++i) *dx++ = *y++ * ss; - } - } - - if (tensor_dy) { - auto* dy = tensor_dy->mutable_data(ctx.GetPlace()); - for (auto j = 0; j < N / B; ++j) { - auto const ss = dz[j]; - for (auto i = 0; i < B; i++) *dy++ = *x++ * ss; - } - } -#endif - } -}; - +// See Note [ Why still keep the original kernel implementation? 
] template class DotKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* tensor_x = ctx.Input("X"); - auto* tensor_y = ctx.Input("Y"); - auto* tensor_out = ctx.Output("Out"); - tensor_out->mutable_data(ctx.GetPlace()); - -#if defined(__NVCC__) || defined(__HIPCC__) - if (1 == tensor_out->dims().size()) { - auto out = framework::EigenScalar::From(*tensor_out); - auto x = framework::EigenVector::Flatten(*tensor_x); - auto y = framework::EigenVector::Flatten(*tensor_y); - - auto& dev = *ctx.template device_context().eigen_device(); - out.device(dev) = (x * y).sum(); - } else { - auto out = framework::EigenMatrix::From(*tensor_out); - auto x = framework::EigenMatrix::From(*tensor_x); - auto y = framework::EigenMatrix::From(*tensor_y); - - auto& dev = *ctx.template device_context().eigen_device(); - out.device(dev) = (x * y).sum(Eigen::DSizes(1)); - } -#else - auto const *x = tensor_x->data(), *x_ = &x[0]; - auto const *y = tensor_y->data(), *y_ = &y[0]; - auto* z = tensor_out->data(); - - // Loop over the total N elements of both operands while sum-reducing every - // B pairs along the way where B is the dimension of the least ordered axis - auto&& d = tensor_x->dims(); - auto const N = tensor_x->numel(); - auto const B = d[d.size() - 1]; - - for (int j = 0; j < N / B; j++) { - T ss = 0; - for (int i = 0; i < B; i++) ss += (*x_++) * (*y_++); - z[j] = ss; - } -#endif + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* out = ctx.Output("Out"); + auto& dev_ctx = ctx.device_context(); + out->mutable_data(x->place()); + + auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); + auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); + auto pt_out = paddle::experimental::MakePtenDenseTensor(*out); + + // call new kernel + pten::DotKernel(dev_ctx, *pt_x.get(), *pt_y.get(), pt_out.get()); } }; @@ -286,8 +63,17 @@ class DotGradKernel : public framework::OpKernel { if (tensor_dx) tensor_dx->mutable_data(ctx.GetPlace()); if (tensor_dy) tensor_dy->mutable_data(ctx.GetPlace()); - DotGradFunction()(tensor_x, tensor_y, tensor_dout, - tensor_dx, tensor_dy, ctx); + auto pt_x = paddle::experimental::MakePtenDenseTensor(*tensor_x); + auto pt_y = paddle::experimental::MakePtenDenseTensor(*tensor_y); + auto pt_dout = paddle::experimental::MakePtenDenseTensor(*tensor_dout); + auto pt_dx = paddle::experimental::MakePtenDenseTensor(*tensor_dx); + auto pt_dy = paddle::experimental::MakePtenDenseTensor(*tensor_dy); + + auto& dev_ctx = ctx.device_context(); + + // call new kernel + pten::DotGradKernel(dev_ctx, *pt_x, *pt_y, *pt_dout, pt_dx.get(), + pt_dy.get()); } }; diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 6546f854df0f4ca7f1e08f3f178ac5c836633312..2be7695e6a8c47a036ef95439a55230ae61ac7c4 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -225,6 +225,10 @@ class Blas { const framework::Tensor& mat_b, const MatDescriptor& dim_b, T alpha, framework::Tensor* mat_out, T beta) const; + template + void MatMul(const T* mat_a, const MatDescriptor& dim_a, const T* mat_b, + const MatDescriptor& dim_b, T alpha, T* mat_out, T beta) const; + template void VINV(int n, const T* a, T* y) const; @@ -253,6 +257,12 @@ class Blas { void BatchedGETRS(CBLAS_TRANSPOSE trans, int n, int nrhs, const T** a, int lda, int* ipiv, T** b, int ldb, int* info, int batch_size) const; + + // cuBlas triangular_solve + template + void BatchedTRSM(CBLAS_SIDE side, CBLAS_UPLO 
uplo, CBLAS_TRANSPOSE transA, + CBLAS_DIAG diag, int M, int N, T alpha, const T** a, int lda, + T** b, int ldb, int batch_size) const; #endif private: @@ -414,6 +424,12 @@ class BlasT : private Blas { void BatchedGETRS(ARGS... args) const { Base()->template BatchedGETRS(args...); } + + // triangular_solve + template + void BatchedTRSM(ARGS... args) const { + Base()->template BatchedTRSM(args...); + } #endif private: diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index cb4044b1b08c7a154a07ecf6b1cf58a84a46876a..be9cf1e3448b623b503138b74062983e3bdfc4d8 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -434,6 +434,17 @@ struct CBlas> { a_, lda, b_, ldb, &beta, c_, ldc); } + static void TRSM(CBLAS_LAYOUT layout, CBLAS_SIDE side, CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans_a, CBLAS_DIAG diag, int M, int N, + paddle::platform::complex alpha, + const paddle::platform::complex *A, int lda, + paddle::platform::complex *B, int ldb) { + const void *a_ = (const void *)(A); + void *b_ = static_cast(B); + platform::dynload::cblas_ctrsm(layout, side, uplo, trans_a, diag, M, N, + &alpha, a_, lda, b_, ldb); + } + template static void GEMM_BATCH(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE *trans_a, CBLAS_TRANSPOSE *trans_b, int *M, int *N, int *K, @@ -562,6 +573,17 @@ struct CBlas> { a_, lda, b_, ldb, &beta, c_, ldc); } + static void TRSM(CBLAS_LAYOUT layout, CBLAS_SIDE side, CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans_a, CBLAS_DIAG diag, int M, int N, + paddle::platform::complex alpha, + const paddle::platform::complex *A, int lda, + paddle::platform::complex *B, int ldb) { + const void *a_ = (const void *)(A); + void *b_ = static_cast(B); + platform::dynload::cblas_ztrsm(layout, side, uplo, trans_a, diag, M, N, + &alpha, a_, lda, b_, ldb); + } + template static void GEMM_BATCH(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE *trans_a, CBLAS_TRANSPOSE *trans_b, int *M, int *N, int *K, @@ -682,6 +704,15 @@ struct CBlas> { cblas_cgemm(layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc); } + + static void TRSM(const CBLAS_LAYOUT layout, const CBLAS_SIDE side, + const CBLAS_UPLO uplo, const CBLAS_TRANSPOSE transA, + const CBLAS_DIAG diag, const int M, const int N, + const paddle::platform::complex alpha, + const paddle::platform::complex *A, const int lda, + paddle::platform::complex *B, const int ldb) { + cblas_ctrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb); + } }; template <> @@ -720,6 +751,15 @@ struct CBlas> { cblas_zgemm(layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc); } + + static void TRSM(const CBLAS_LAYOUT layout, const CBLAS_SIDE side, + const CBLAS_UPLO uplo, const CBLAS_TRANSPOSE transA, + const CBLAS_DIAG diag, const int M, const int N, + const paddle::platform::complex alpha, + const paddle::platform::complex *A, const int lda, + paddle::platform::complex *B, const int ldb) { + cblas_ztrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb); + } }; #endif @@ -1209,6 +1249,15 @@ void Blas::MatMul(const framework::Tensor &mat_a, const framework::Tensor &mat_b, const MatDescriptor &dim_b, T alpha, framework::Tensor *mat_out, T beta) const { + MatMul(mat_a.data(), dim_a, mat_b.data(), dim_b, alpha, + mat_out->data(), beta); +} + +template +template +void Blas::MatMul(const T *mat_a, const MatDescriptor &dim_a, + const T *mat_b, const MatDescriptor &dim_b, + T alpha, T *mat_out, T beta) const { PADDLE_ENFORCE_EQ( dim_a.width_, dim_b.height_, 
platform::errors::InvalidArgument( @@ -1221,8 +1270,7 @@ void Blas::MatMul(const framework::Tensor &mat_a, CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans; if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) { this->template GEMM(transA, transB, dim_a.height_, dim_b.width_, - dim_a.width_, alpha, mat_a.data(), - mat_b.data(), beta, mat_out->data()); + dim_a.width_, alpha, mat_a, mat_b, beta, mat_out); } else { PADDLE_ENFORCE_EQ( dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 || @@ -1233,8 +1281,8 @@ void Blas::MatMul(const framework::Tensor &mat_a, "But got dim_a.batch_size = %d, dim_b.batch_size = %d.", dim_a.batch_size_, dim_b.batch_size_)); this->template BatchedGEMM( - transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha, - mat_a.data(), mat_b.data(), beta, mat_out->data(), + transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha, mat_a, + mat_b, beta, mat_out, dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_, dim_a.stride_, dim_b.stride_); } diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index 1b609b15d6e56934a460b6d2ec249f7dc6a916d6..a5eca7b225558a455563a8cf548d74a251828d49 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -19,6 +19,81 @@ namespace paddle { namespace operators { +static framework::DDim GetDimForInput(const framework::InferShapeContext& ctx, + const std::string input_name) { + auto shape = ctx.Attrs().Get>("fused_reshape_" + input_name); + auto axis = + ctx.Attrs().Get>("fused_transpose_" + input_name); + auto dim = ctx.GetInputDim(input_name); + + PADDLE_ENFORCE_GT(dim.size(), 0, + platform::errors::InvalidArgument( + "The Input(%s) has not been initialized properly. 
The " + "shape of Input(%s) = [%s].", + dim)); + + // if mkldnn reshape+transpose+matmul fuse activated + if (!shape.empty() && !axis.empty()) { + PADDLE_ENFORCE_GE( + shape.size(), 2, + platform::errors::InvalidArgument( + "shape_%s attribute of MatMulOp was implemented for 2, 3 " + "or 4 dimensions.", + input_name)); + PADDLE_ENFORCE_LE( + shape.size(), 4, + platform::errors::InvalidArgument( + "shape_%s attribute of MatMulOp was implemented for 2, 3 " + "or 4 dimensions.", + input_name)); + PADDLE_ENFORCE_EQ( + shape.size(), axis.size(), + platform::errors::InvalidArgument( + "Ranks of shape_%s and axis_%s attributes of MatMulOp " + "must be equal.", + input_name, input_name)); + + int num_negative = std::count(shape.begin(), shape.end(), -1); + PADDLE_ENFORCE_LE(num_negative, 1, + platform::errors::InvalidArgument( + "The max number of -1 in fused_reshape_%s is 1 " + "but received %d.", + input_name, num_negative)); + + auto it_zero = std::find(shape.begin(), shape.end(), 0); + if (it_zero != shape.end()) { + for (uint64_t i = 0; i < shape.size(); i++) { + if (shape[i] == 0) { + PADDLE_ENFORCE_LT(i, dim.size(), + platform::errors::InvalidArgument( + "The index of 0 in fused_reshape_%s ", + "should be less than output dim size, ", + "but the index is %d and output dim size is %d", + input_name, i, dim.size())); + shape[i] = dim.at(i); + } + } + } + + // if "-1" is present then one of reshape dims must be infered + auto it_negative = std::find(shape.begin(), shape.end(), -1); + if (it_negative != shape.end()) { + int64_t dim_product = 1; + for (int i = 0; i < dim.size(); i++) { + dim_product *= dim.at(i); + } + + int64_t shape_product = std::accumulate(shape.begin(), shape.end(), -1, + std::multiplies()); + int index = std::distance(shape.begin(), it_negative); + shape[index] = dim_product / shape_product; + } + + dim = dim.reshape(shape).transpose(axis); + } + return dim; +} + class MatMulV2Op : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -30,9 +105,9 @@ class MatMulV2Op : public framework::OperatorWithKernel { bool trans_y = ctx->Attrs().Get("trans_y"); std::vector dims_x = - paddle::framework::vectorize(ctx->GetInputDim("X")); + framework::vectorize(GetDimForInput(*ctx, "X")); std::vector dims_y = - paddle::framework::vectorize(ctx->GetInputDim("Y")); + framework::vectorize(GetDimForInput(*ctx, "Y")); auto ndims_x = dims_x.size(); auto ndims_y = dims_y.size(); PADDLE_ENFORCE_GT(ndims_x, 0, @@ -119,9 +194,32 @@ class MatMulV2Op : public framework::OperatorWithKernel { "received %d", reshape_out_size)); - auto it = std::find(reshape_out.begin(), reshape_out.end(), -1); + // int num_negative = std::count(reshape_out.begin(), reshape_out.end(), + // -1); + // PADDLE_ENFORCE_LE(num_negative, 1, + // platform::errors::InvalidArgument( + // "The max number of -1 in fused_reshape_Out is 1 " + // "but received %d.", + // num_negative)); + + // auto it_zero = std::find(reshape_out.begin(), reshape_out.end(), 0); + // if (it_zero != reshape_out.end()) { + // for (uint64_t i = 0; i < reshape_out.size(); i++) { + // if (reshape_out[i] == 0) { + // PADDLE_ENFORCE_LT( + // i, ddim_out.size(), + // platform::errors::InvalidArgument( + // "The index of 0 in fused_reshape_Out ", + // "should be less than output dim size, ", + // "but the index is %d and output dim size is %d", i, + // ddim_out.size())); + // reshape_out[i] = ddim_out.at(i); + // } + // } + // } // if "-1" is present then one of reshape dims must be infered + auto it = 
std::find(reshape_out.begin(), reshape_out.end(), -1); if (it != reshape_out.end()) { int index = std::distance(reshape_out.begin(), it); @@ -215,6 +313,22 @@ class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault("float32") .InEnum({"float32", "bfloat16"}) .AsExtra(); + AddAttr>("fused_reshape_X", + R"DOC(Shape of fused reshape of `X` input.)DOC") + .SetDefault({}) + .AsExtra(); + AddAttr>("fused_reshape_Y", + R"DOC(Shape of fused reshape of `Y` input.)DOC") + .SetDefault({}) + .AsExtra(); + AddAttr>("fused_transpose_X", + R"DOC(Axis of fused transpose of `X` input.)DOC") + .SetDefault({}) + .AsExtra(); + AddAttr>("fused_transpose_Y", + R"DOC(Axis of fused transpose of `Y` input.)DOC") + .SetDefault({}) + .AsExtra(); AddComment( R"DOC(Matrix multiplication Out = X * Y. A has shape (d0, d1 ... M, K), B has shape (d0, d1 ... K, N), Out has shape ((d0, d1 ... M, N)). @@ -275,6 +389,14 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel { tensor.place(), tensor.layout()); } } + + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext& ctx) const override { + return framework::KernelSignature( + "matmul_grad", {"X", "Y", framework::GradVarName("Out")}, + {"trans_x", "trans_y"}, + {framework::GradVarName("X"), framework::GradVarName("Y")}); + } }; template @@ -317,6 +439,13 @@ class MatMulV2OpDoubleGrad : public framework::OperatorWithKernel { context->ShareDim("DOut", "DDOut"); } } + + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext& ctx) const override { + return framework::KernelSignature( + "matmul_double_grad", {"X", "Y", "DOut", "DDX", "DDY"}, + {"trans_x", "trans_y"}, {"DX", "DY", "DDOut"}); + } }; template @@ -347,6 +476,85 @@ class MatMulV2OpDoubleGradMaker : public framework::SingleGradOpMaker { op->SetAttrMap(this->Attrs()); } }; +class MatMulV2OpTripleGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* context) const override { + OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", + "matmul_v2_triple_grad"); + OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", + "matmul_v2_triple_grad"); + OP_INOUT_CHECK(context->HasInput("DOut"), "Input", "DOut", + "matmul_v2_triple_grad"); + OP_INOUT_CHECK(context->HasInput("DDX"), "Input", "DDX", + "matmul_v2_triple_grad"); + OP_INOUT_CHECK(context->HasInput("DDY"), "Input", "DDY", + "matmul_v2_triple_grad"); + OP_INOUT_CHECK(context->HasInput("D_DX"), "Input", "D_DX", + "matmul_v2_triple_grad"); + OP_INOUT_CHECK(context->HasInput("D_DY"), "Input", "D_DY", + "matmul_v2_triple_grad"); + OP_INOUT_CHECK(context->HasInput("D_DDOut"), "Input", "D_DDOut", + "matmul_v2_triple_grad"); + + if (context->HasOutput("D_X_out")) { + context->ShareDim("X", "D_X_out"); + } + if (context->HasOutput("D_Y_out")) { + context->ShareDim("Y", "D_Y_out"); + } + if (context->HasOutput("D_DOut_out")) { + context->ShareDim("DOut", "D_DOut_out"); + } + if (context->HasOutput("D_DDX_out")) { + context->ShareDim("X", "D_DDX_out"); + } + if (context->HasOutput("D_DDY_out")) { + context->ShareDim("Y", "D_DDY_out"); + } + } + + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext& ctx) const override { + return framework::KernelSignature( + "matmul_triple_grad", + {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"}, + {"trans_x", "trans_y"}, + {"D_X_out", "D_Y_out", "D_DOut_out", "D_DDX_out", 
"D_DDY_out"}); + } +}; + +template +class MatMulV2OpTripleGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("matmul_v2_triple_grad"); + + // get input from double grad + op->SetInput("X", this->Input("X")); + op->SetInput("Y", this->Input("Y")); + op->SetInput("DOut", this->Input("DOut")); + op->SetInput("DDX", this->Input("DDX")); + op->SetInput("DDY", this->Input("DDY")); + op->SetInput("D_DX", this->OutputGrad("DX")); + op->SetInput("D_DY", this->OutputGrad("DY")); + op->SetInput("D_DDOut", this->OutputGrad("DDOut")); + + // set outputs + op->SetOutput("D_X_out", this->InputGrad("X")); + op->SetOutput("D_Y_out", this->InputGrad("Y")); + op->SetOutput("D_DOut_out", this->InputGrad("DOut")); + op->SetOutput("D_DDX_out", this->InputGrad("DDX")); + op->SetOutput("D_DDY_out", this->InputGrad("DDY")); + + op->SetAttrMap(this->Attrs()); + } +}; } // namespace operators } // namespace paddle @@ -359,7 +567,11 @@ REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad, ops::MatMulV2OpDoubleGradMaker, ops::MatMulV2OpDoubleGradMaker); -REGISTER_OPERATOR(matmul_v2_grad_grad, ops::MatMulV2OpDoubleGrad); +REGISTER_OPERATOR(matmul_v2_grad_grad, ops::MatMulV2OpDoubleGrad, + ops::MatMulV2OpTripleGradMaker, + ops::MatMulV2OpTripleGradMaker); + +REGISTER_OPERATOR(matmul_v2_triple_grad, ops::MatMulV2OpTripleGrad); REGISTER_OP_CPU_KERNEL( matmul_v2, ops::MatMulV2Kernel, @@ -385,3 +597,12 @@ REGISTER_OP_CPU_KERNEL( paddle::platform::complex>, ops::MatMulV2DoubleGradKernel>); + +REGISTER_OP_CPU_KERNEL( + matmul_v2_triple_grad, + ops::MatMulV2TripleGradKernel, + ops::MatMulV2TripleGradKernel, + ops::MatMulV2TripleGradKernel>, + ops::MatMulV2TripleGradKernel>); diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index dd9940db29f7739b54a5fe26d89746f0eceb2b2c..e93bd212868fd0183bc85d55b75146e9e7ebd1ab 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -25,8 +25,14 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/math/complex_functors.h" #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" +// only can include the headers in paddle/pten/api dirs +#include "paddle/pten/api/lib/utils/tensor_utils.h" +#include "paddle/pten/include/core.h" +#include "paddle/pten/kernels/matmul_grad_kernel.h" +#include "paddle/pten/kernels/matmul_kernel.h" + #if defined(__NVCC__) || defined(__HIPCC__) -#include "paddle/fluid/operators/reduce_ops/cub_reduce.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #endif namespace paddle { @@ -34,343 +40,6 @@ namespace operators { using framework::Tensor; -struct IdentityFunctor { - HOSTDEVICE explicit inline IdentityFunctor() {} - - template - HOSTDEVICE inline U operator()(const U& x) const { - return x; - } -}; - -template -void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output, - const std::vector& reduce_dims, - const paddle::framework::ExecutionContext& ctx) { -#if defined(__NVCC__) || defined(__HIPCC__) - auto stream = ctx.cuda_device_context().stream(); - TensorReduce(*input, output, reduce_dims, - static_cast(0), cub::Sum(), - IdentityFunctor(), stream); -#else - ReduceKernelFunctor( - input, output, reduce_dims, true, false, ctx) - .template apply(); -#endif -} - -static void GetBroadcastFromDims(const int x_ndim, const std::int64_t* x_dims, - const int y_ndim, const std::int64_t* y_dims, - std::int64_t* x_bd_dims, - std::int64_t* y_bd_dims, - std::int64_t* out_bd_dims) { - const int ndim = (std::max)(x_ndim, y_ndim); - std::fill(x_bd_dims, x_bd_dims + ndim - x_ndim, 1); - std::fill(y_bd_dims, y_bd_dims + ndim - y_ndim, 1); - std::copy(x_dims, x_dims + x_ndim, x_bd_dims + ndim - x_ndim); - std::copy(y_dims, y_dims + y_ndim, y_bd_dims + ndim - y_ndim); - - for (int i = 0; i < ndim; ++i) { - PADDLE_ENFORCE_EQ( - x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] <= 1 || y_bd_dims[i] <= 1, - true, - platform::errors::InvalidArgument( - "Input(X) and Input(Y) has error dim." - "X_broadcast's shape[%s] must be equal to Y_broadcast's shape[%s]," - "or X_broadcast's shape[%s] <= 1, or Y_broadcast's shape[%s] <= 1," - "But received X_broadcast's shape[%s] = [%s]" - "received Y_broadcast's shape[%s] = [%s]", - i, i, i, i, i, x_bd_dims[i], i, y_bd_dims[i])); - if (x_bd_dims[i] == 0 || y_bd_dims[i] == 0) { - out_bd_dims[i] = 0; - } else { - out_bd_dims[i] = (std::max)(x_bd_dims[i], y_bd_dims[i]); - } - } -} - -static int64_t GetIndexMessage(const int n, const int64_t* dims, - const int64_t* index) { - int64_t sum = 0; - for (int i = 0; i < n; ++i) { - if (dims[i] > 1) { - sum = sum * dims[i] + index[i]; - } - } - return sum; -} - -static void IndexIncreaseFromDims(const int ndim, const int64_t* dims, - int64_t* index) { - for (int i = ndim - 1; i >= 0; --i) { - ++index[i]; - if (index[i] >= dims[i]) { - index[i] -= dims[i]; - } else { - break; - } - } -} - -template -void MatMulFunction(const Tensor* X, const Tensor* Y, - const std::vector& x_dims, - const std::vector& y_dims, Tensor* Out, - bool trans_x, bool trans_y, - const paddle::framework::ExecutionContext& ctx, - bool flag = false) { - const int x_ndim = x_dims.size(); - const int y_ndim = y_dims.size(); - - // Get data ptr - const T* x_data = X->data(); - const T* y_data = Y->data(); - - if (x_ndim == 1 && y_ndim == 1) { - PADDLE_ENFORCE_EQ( - X->numel(), Y->numel(), - platform::errors::InvalidArgument( - "X's numbers must be equal to Y's numbers," - "when X/Y's dims =1. 
But received X has [%d] elements," - "received Y has [%d] elements", - X->numel(), Y->numel())); - VLOG(3) << "MatMul's case 1"; - Out->Resize({1}); - Out->mutable_data(ctx.GetPlace()); - auto out_eigen = framework::EigenScalar::From(*Out); - auto x_eigen = framework::EigenVector::Flatten(*X); - auto y_eigen = framework::EigenVector::Flatten(*Y); - - auto& dev = *ctx.template device_context().eigen_device(); - if (flag) { - out_eigen.device(dev) = (x_eigen * y_eigen).sum() + out_eigen; - } else { - out_eigen.device(dev) = (x_eigen * y_eigen).sum(); - } - return; - } - - auto& dev_ctx = ctx.template device_context(); - auto blas = math::GetBlas(dev_ctx); - - if (x_ndim == 1) { - const int N = X->numel(); - if (trans_y) { - PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], N, - platform::errors::InvalidArgument( - "Input(Y) has error dim." - "Y'dims[%d] must be equal to %d" - "But received Y'dims[%d] is %d", - y_ndim - 1, N, y_ndim - 1, y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], N, - platform::errors::InvalidArgument( - "Input(Y) has error dim." - "Y'dims[%d] must be equal to %d" - "But received Y'dims[%d] is %d", - y_ndim - 2, N, y_ndim - 2, y_dims[y_ndim - 2])); - } - std::vector out_dims(y_ndim - 1); - if (trans_y) { - std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin()); - } else { - std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin()); - out_dims.back() = y_dims.back(); - } - Out->Resize(framework::make_ddim(out_dims)); - Out->mutable_data(ctx.GetPlace()); - if (trans_y) { - const int M = Y->numel() / N; - VLOG(3) << "MatMul's case 2"; - blas.GEMV(false, M, N, static_cast(1), y_data, x_data, - static_cast(flag), Out->data()); - } else { - const int M = y_dims[y_ndim - 1]; - const int batch_size = Y->numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul's case 3"; - blas.GEMV(true, N, M, static_cast(1), y_data, x_data, - static_cast(flag), Out->data()); - } else { - VLOG(3) << "MatMul's case 4"; - blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, static_cast(1), - y_data, x_data, static_cast(flag), Out->data(), - batch_size, M * N, 0); - } - } - return; - } - - if (y_ndim == 1) { - const int N = Y->numel(); - if (trans_x) { - PADDLE_ENFORCE_EQ(x_dims[x_ndim - 2], N, - platform::errors::InvalidArgument( - "Input(X) has error dim." - "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 2, N, x_ndim - 2, x_dims[x_ndim - 2])); - } else { - PADDLE_ENFORCE_EQ(x_dims[x_ndim - 1], N, - platform::errors::InvalidArgument( - "Input(X) has error dim." 
- "X'dims[%d] must be equal to %d" - "But received X'dims[%d] is %d", - x_ndim - 1, N, x_ndim - 1, x_dims[x_ndim - 1])); - } - std::vector out_dims(x_ndim - 1); - if (trans_x) { - std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin()); - out_dims.back() = x_dims.back(); - } else { - std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin()); - } - Out->Resize(framework::make_ddim(out_dims)); - Out->mutable_data(ctx.GetPlace()); - - if (trans_x) { - const int M = x_dims[x_ndim - 1]; - const int batch_size = X->numel() / (M * N); - if (batch_size == 1) { - VLOG(3) << "MatMul's case 5"; - blas.GEMV(true, N, M, static_cast(1), x_data, y_data, - static_cast(flag), Out->data()); - } else { - VLOG(3) << "MatMul's case 6"; - blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, static_cast(1), - x_data, y_data, static_cast(flag), Out->data(), - batch_size, M * N, 0); - } - } else { - const int M = X->numel() / N; - VLOG(3) << "MatMul's case 7"; - blas.GEMV(false, M, N, static_cast(1), x_data, y_data, - static_cast(flag), Out->data()); - } - return; - } - - const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; - const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; - if (trans_y) { - PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], K, - platform::errors::InvalidArgument( - "Input(Y) has error dim." - "Y'dims[%d] must be equal to %d" - "But received Y'dims[%d] is %d", - y_ndim - 1, K, y_ndim - 1, y_dims[y_ndim - 1])); - } else { - PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], K, - platform::errors::InvalidArgument( - "Input(Y) has error dim." - "Y'dims[%d] must be equal to %d" - "But received Y'dims[%d] is %d", - y_ndim - 2, K, y_ndim - 2, y_dims[y_ndim - 2])); - } - const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; - const int ndim = (std::max)(x_ndim, y_ndim); - std::vector x_broadcast_dims(ndim); - std::vector y_broadcast_dims(ndim); - std::vector out_broadcast_dims(ndim); - - GetBroadcastFromDims(x_ndim - 2, x_dims.data(), y_ndim - 2, y_dims.data(), - x_broadcast_dims.data(), y_broadcast_dims.data(), - out_broadcast_dims.data()); - - out_broadcast_dims[ndim - 2] = M; - out_broadcast_dims[ndim - 1] = N; - - Out->Resize(framework::make_ddim(out_broadcast_dims)); - Out->mutable_data(ctx.GetPlace()); - - const int batch_dim = ndim - 2; - // broadcast message - const bool is_broadcast_dims = !std::equal( - x_broadcast_dims.cbegin(), x_broadcast_dims.cbegin() + batch_dim, - y_broadcast_dims.cbegin()); - - const std::int64_t x_batch_size = std::accumulate( - x_broadcast_dims.cbegin(), x_broadcast_dims.cbegin() + batch_dim, 1LL, - std::multiplies()); - const std::int64_t y_batch_size = std::accumulate( - y_broadcast_dims.cbegin(), y_broadcast_dims.cbegin() + batch_dim, 1LL, - std::multiplies()); - const std::int64_t out_batch_size = std::accumulate( - out_broadcast_dims.cbegin(), out_broadcast_dims.cbegin() + batch_dim, 1LL, - std::multiplies()); - if (out_batch_size == 0) return; - if (x_batch_size == 1 && y_batch_size == 1) { - VLOG(3) << "MatMul's case 8"; - blas.GEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, M, N, K, static_cast(1), - x_data, y_data, static_cast(flag), Out->data()); - } else if (x_batch_size == 1) { - if (M == 1 && trans_y) { - VLOG(3) << "MatMul's case 9"; - blas.GEMV(false, y_batch_size * N, K, static_cast(1), y_data, x_data, - static_cast(flag), Out->data()); - } else { - VLOG(3) << "MatMul's case 10"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? 
CblasTrans : CblasNoTrans, M, N, K, - static_cast(1), x_data, y_data, static_cast(flag), - Out->data(), out_batch_size, 0, K * N); - } - } else if (y_batch_size == 1) { - if (!trans_x) { - VLOG(3) << "MatMul's case 11"; - blas.GEMM(CblasNoTrans, trans_y ? CblasTrans : CblasNoTrans, - x_batch_size * M, N, K, static_cast(1), x_data, y_data, - static_cast(flag), Out->data()); - } else { - VLOG(3) << "MatMul's case 12"; - blas.BatchedGEMM(CblasTrans, trans_y ? CblasTrans : CblasNoTrans, M, N, K, - static_cast(1), x_data, y_data, static_cast(flag), - Out->data(), out_batch_size, M * K, 0); - } - } else if (!is_broadcast_dims) { - VLOG(3) << "MatMul's case 13"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, M, N, K, - static_cast(1), x_data, y_data, static_cast(flag), - Out->data(), out_batch_size, M * K, K * N); - } else { - // in the case, can't use stridedgemm - std::vector x_ptr(out_batch_size); - std::vector y_ptr(out_batch_size); - std::vector out_ptr(out_batch_size); - std::vector index(batch_dim, 0); - for (std::int64_t i = 0; i < out_batch_size; ++i) { - // using the index to get offset - const std::int64_t x_index = - GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data()); - const std::int64_t y_index = - GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data()); - - x_ptr[i] = x_data + x_index * M * K; - y_ptr[i] = y_data + y_index * K * N; - out_ptr[i] = Out->data() + i * M * N; - IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data()); - } - VLOG(3) << "MatMul's case 14"; - blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, - trans_y ? CblasTrans : CblasNoTrans, M, N, K, - static_cast(1), x_ptr.data(), y_ptr.data(), - static_cast(flag), out_ptr.data(), out_batch_size); - } -} - -template -void MatMulFunction(const Tensor* X, const Tensor* Y, Tensor* Out, bool trans_x, - bool trans_y, - const paddle::framework::ExecutionContext& ctx, - bool flag = false) { - const std::vector x_dims = vectorize(X->dims()); - const std::vector y_dims = vectorize(Y->dims()); - MatMulFunction(X, Y, x_dims, y_dims, Out, trans_x, trans_y, - ctx, flag); -} - template class MatMulV2Kernel : public framework::OpKernel { public: @@ -380,15 +49,17 @@ class MatMulV2Kernel : public framework::OpKernel { auto* Out = ctx.Output("Out"); bool trans_x = ctx.Attr("trans_x"); bool trans_y = ctx.Attr("trans_y"); - PADDLE_ENFORCE_NE(framework::product(X->dims()), 0, - platform::errors::InvalidArgument( - "The Input(X) dims size must not be equal 0," - " but reviced dims size is 0. ")); - PADDLE_ENFORCE_NE(framework::product(Y->dims()), 0, - platform::errors::InvalidArgument( - "The Input(Y) dims size must not be equal 0," - " but reviced dims size is 0. ")); - MatMulFunction(X, Y, Out, trans_x, trans_y, ctx); + + auto& dev_ctx = ctx.device_context(); + Out->mutable_data(X->place()); + + auto pt_x = paddle::experimental::MakePtenDenseTensor(*X); + auto pt_y = paddle::experimental::MakePtenDenseTensor(*Y); + auto pt_out = paddle::experimental::MakePtenDenseTensor(*Out); + + // call new kernel + pten::MatmulKernel(dev_ctx, *pt_x, *pt_y, trans_x, trans_y, + pt_out.get()); } }; @@ -403,26 +74,6 @@ static framework::Tensor FoldInitDims(const framework::Tensor& input) { return output; } -// Reshape a rank-3 tensor from P x M x N to M x (P * N). -// (Warning: This requires transposing data and writes into new memory.) -// Identity op if the tensor is not of rank 3. 
-template -static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context, - const framework::Tensor& input) { - auto in_dims = input.dims(); - if (in_dims.size() != 3) { - return input; - } - framework::Tensor output; - output.Resize({in_dims[1], in_dims[0], in_dims[2]}); - output.mutable_data(context.GetPlace()); - std::vector axis = {1, 0, 2}; - math::Transpose trans; - trans(context, input, &output, axis); - output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); - return output; -} - /** * Get row matrix shape from a vector shape. If the rank of x_dim > 1, the * original x_dim is returned. @@ -485,585 +136,45 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x, ReshapeTensorIntoMatrixSequence(y, mat_dim_y); } -template -struct ConjHelper { - explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {} - HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) { - dst.Resize(src.dims()); - dst.set_layout(src.layout()); - dst.ShareDataWith(src); - return; - } - - const framework::ExecutionContext& ctx_; -}; - -template -struct ConjHelper> { - explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {} - - HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) { - dst.Resize(src.dims()); - auto* src_data = src.data>(); - auto* dst_data = dst.mutable_data>( - ctx_.GetPlace(), - size_t(src.numel() * sizeof(paddle::platform::complex))); - - platform::ForRange for_range( - ctx_.template device_context(), src.numel()); - math::ConjFunctor> functor( - src_data, src.numel(), dst_data); - for_range(functor); - return; - } - const framework::ExecutionContext& ctx_; -}; - -template -struct ConjHelper> { - explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {} - - HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) { - dst.Resize(src.dims()); - auto* src_data = src.data>(); - auto* dst_data = dst.mutable_data>( - ctx_.GetPlace(), - size_t(src.numel() * sizeof(paddle::platform::complex))); - - platform::ForRange for_range( - ctx_.template device_context(), src.numel()); - math::ConjFunctor> functor( - src_data, src.numel(), dst_data); - for_range(functor); - return; - } - const framework::ExecutionContext& ctx_; -}; - -template -struct DotDoubleGradFunction { - void operator()(const Tensor* tensor_x, const Tensor* tensor_y, - Tensor* tensor_dx, Tensor* tensor_dy, - const Tensor* tensor_dout, const Tensor* tensor_ddx, - const Tensor* tensor_ddy, Tensor* tensor_ddout, - const paddle::framework::ExecutionContext& ctx); -}; - -template -struct DotDoubleGradFunction> { - void operator()(const Tensor* tensor_x, const Tensor* tensor_y, - Tensor* tensor_dx, Tensor* tensor_dy, - const Tensor* tensor_dout, const Tensor* tensor_ddx, - const Tensor* tensor_ddy, Tensor* tensor_ddout, - const paddle::framework::ExecutionContext& ctx) { -#if defined(__NVCC__) || defined(__HIPCC__) - if (1 == tensor_dout->dims().size()) { - framework::Tensor tensor_dout_help; - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - if (tensor_dx || tensor_dy) { - tensor_dout_help.Resize(tensor_dout->dims()); - tensor_dout_help.mutable_data(ctx.GetPlace()); - paddle::platform::ForRange for_range( - dev_raw, tensor_dout->numel()); - math::ConjFunctor functor(tensor_dout->data(), - tensor_dout->numel(), - tensor_dout_help.data()); - for_range(functor); - } - if (tensor_dx) { - auto ddy = framework::EigenVector::Flatten(*tensor_ddy); - Eigen::DSizes 
size(tensor_ddy->numel()); - auto dx = framework::EigenVector::Flatten(*tensor_dx); - auto dout = framework::EigenVector::Flatten(tensor_dout_help); - dx.device(dev) = ddy * dout.broadcast(size); - } - - if (tensor_dy) { - auto ddx = framework::EigenVector::Flatten(*tensor_ddx); - Eigen::DSizes size(tensor_ddx->numel()); - auto dy = framework::EigenVector::Flatten(*tensor_dy); - auto dout = framework::EigenVector::Flatten(tensor_dout_help); - dy.device(dev) = ddx * dout.broadcast(size); - } - - if (tensor_ddout) { - framework::Tensor tensor_x_help, tensor_y_help; - tensor_x_help.Resize(tensor_x->dims()); - tensor_x_help.mutable_data(ctx.GetPlace()); - tensor_y_help.Resize(tensor_y->dims()); - tensor_y_help.mutable_data(ctx.GetPlace()); - - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - paddle::platform::ForRange for_range(dev_raw, - tensor_x->numel()); - math::ConjFunctor functor_x(tensor_x->data(), tensor_x->numel(), - tensor_x_help.data()); - for_range(functor_x); - math::ConjFunctor functor_y(tensor_y->data(), tensor_y->numel(), - tensor_y_help.data()); - for_range(functor_y); - auto x = framework::EigenVector::Flatten(tensor_x_help); - auto y = framework::EigenVector::Flatten(tensor_y_help); - auto ddx = framework::EigenVector::Flatten(*tensor_ddx); - auto ddy = framework::EigenVector::Flatten(*tensor_ddy); - auto ddout = framework::EigenVector::Flatten(*tensor_ddout); - ddout.device(dev) = (x * ddy + y * ddx).sum(); - } - } -#else - const auto* data_dout = tensor_dout->data(); - - if (tensor_dx) { - auto* data_dx = tensor_dx->mutable_data(ctx.GetPlace()); - const auto* data_ddy = tensor_ddy->data(); - const framework::DDim& dim = tensor_dx->dims(); - size_t N = static_cast(framework::product(dim)); - - auto step = dim[dim.size() - 1]; - - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dx[i] = T(data_dout[s].real, -data_dout[s].imag) * data_ddy[i]; - } - } - - if (tensor_dy) { - auto* data_dy = tensor_dy->mutable_data(ctx.GetPlace()); - const auto* data_ddx = tensor_ddx->data(); - const framework::DDim& dim = tensor_dy->dims(); - size_t N = static_cast(framework::product(dim)); - - auto step = dim[dim.size() - 1]; - - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dy[i] = T(data_dout[s].real, -data_dout[s].imag) * data_ddx[i]; - } - } - - if (tensor_ddout) { - auto* data_ddout = tensor_ddout->mutable_data(ctx.GetPlace()); - auto* data_x = tensor_x->data(); - auto* data_y = tensor_y->data(); - auto* data_ddx = tensor_ddx->data(); - auto* data_ddy = tensor_ddy->data(); - - const framework::DDim& dim = tensor_dy->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - bool new_s = false; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) { - ++s; - new_s = true; - } - if (new_s) { - data_ddout[s] = T(data_x[i].real, -data_x[i].imag) * data_ddy[i] + - T(data_y[i].real, -data_y[i].imag) * data_ddx[i]; - } else { - data_ddout[s] += T(data_x[i].real, -data_x[i].imag) * data_ddy[i] + - T(data_y[i].real, -data_y[i].imag) * data_ddx[i]; - } - new_s = false; - } - } -#endif - } -}; - -template -struct DotDoubleGradFunction> { - void operator()(const Tensor* tensor_x, const Tensor* tensor_y, - Tensor* tensor_dx, Tensor* tensor_dy, - const Tensor* tensor_dout, const Tensor* tensor_ddx, - const Tensor* tensor_ddy, Tensor* tensor_ddout, - const paddle::framework::ExecutionContext& ctx) { -#if defined(__NVCC__) || defined(__HIPCC__) - 
if (1 == tensor_dout->dims().size()) { - auto& dev_raw = ctx.template device_context(); - auto& dev = *dev_raw.eigen_device(); - auto dout = framework::EigenVector::Flatten(*tensor_dout); - if (tensor_dx) { - tensor_dx->mutable_data(ctx.GetPlace()); - auto ddy = framework::EigenVector::Flatten(*tensor_ddy); - Eigen::DSizes size(tensor_ddy->numel()); - auto dx = framework::EigenVector::Flatten(*tensor_dx); - dx.device(dev) = ddy * dout.broadcast(size); - } - - if (tensor_dy) { - tensor_dy->mutable_data(ctx.GetPlace()); - auto ddx = framework::EigenVector::Flatten(*tensor_ddx); - Eigen::DSizes size(tensor_ddx->numel()); - - auto dy = framework::EigenVector::Flatten(*tensor_dy); - dy.device(dev) = ddx * dout.broadcast(size); - } - - if (tensor_ddout) { - tensor_ddout->mutable_data(ctx.GetPlace()); - auto x = framework::EigenVector::Flatten(*tensor_x); - auto y = framework::EigenVector::Flatten(*tensor_y); - auto ddx = framework::EigenVector::Flatten(*tensor_ddx); - auto ddy = framework::EigenVector::Flatten(*tensor_ddy); - auto ddout = framework::EigenVector::Flatten(*tensor_ddout); - ddout.device(dev) = (x * ddy + y * ddx).sum(); - } - } -#else - const auto* data_dout = tensor_dout->data(); - - if (tensor_dx) { - auto* data_dx = tensor_dx->mutable_data(ctx.GetPlace()); - const auto* data_ddy = tensor_ddy->data(); - const framework::DDim& dim = tensor_dx->dims(); - size_t N = static_cast(framework::product(dim)); - - auto step = dim[dim.size() - 1]; - - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dx[i] = data_dout[s] * data_ddy[i]; - } - } - - if (tensor_dy) { - auto* data_dy = tensor_dy->mutable_data(ctx.GetPlace()); - const auto* data_ddx = tensor_ddx->data(); - const framework::DDim& dim = tensor_dy->dims(); - size_t N = static_cast(framework::product(dim)); - - auto step = dim[dim.size() - 1]; - - int s = -1; - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) ++s; - data_dy[i] = data_dout[s] * data_ddx[i]; - } - } - - if (tensor_ddout) { - auto* data_ddout = tensor_ddout->mutable_data(ctx.GetPlace()); - auto* data_x = tensor_x->data(); - auto* data_y = tensor_y->data(); - auto* data_ddx = tensor_ddx->data(); - auto* data_ddy = tensor_ddy->data(); - - const framework::DDim& dim = tensor_dy->dims(); - size_t N = static_cast(framework::product(dim)); - auto step = dim[dim.size() - 1]; - int s = -1; - bool new_s = false; - - for (size_t i = 0; i < N; ++i) { - if (0 == i % step) { - ++s; - new_s = true; - } - if (new_s) { - data_ddout[s] = data_x[i] * data_ddy[i] + data_y[i] * data_ddx[i]; - } else { - data_ddout[s] += data_x[i] * data_ddy[i] + data_y[i] * data_ddx[i]; - } - new_s = false; - } - } -#endif - } -}; - template class MatMulV2GradKernel : public framework::OpKernel { public: - void MatMul(const framework::ExecutionContext& context, - const framework::Tensor& a, bool trans_a, - const framework::Tensor& b, bool trans_b, - framework::Tensor* out) const { - out->mutable_data(context.GetPlace()); - auto blas = math::GetBlas(context); - auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); - auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); - if (a.dims().size() == 3 && b.dims().size() <= 2) { - // the transpose_X must be false, if is true, the transpose cost much time - if (!trans_a) { - mat_dim_a.height_ *= mat_dim_a.batch_size_; - mat_dim_a.batch_size_ = 0; - } - } - blas.MatMul(a, mat_dim_a, b, mat_dim_b, static_cast(1), out, - static_cast(0)); - } - - void CalcInputGrad(const framework::ExecutionContext& 
context, - const framework::Tensor& a, bool trans_a, - bool is_fold_init_dims_a, const framework::Tensor& b, - bool trans_b, bool is_fold_init_dims_b, - framework::Tensor* out) const { - if (out == nullptr) return; - bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) && - out->dims().size() == 2; - if (!need_combine) { - MatMul(context, a, trans_a, b, trans_b, out); - } else { - auto& ctx = context.template device_context(); - MatMul(context, is_fold_init_dims_a - ? FoldInitDims(a) - : FoldHeadAndLastDims(ctx, a), - trans_a, is_fold_init_dims_b - ? FoldInitDims(b) - : FoldHeadAndLastDims(ctx, b), - trans_b, out); - } - } - void Compute(const framework::ExecutionContext& ctx) const override { bool transpose_x = ctx.Attr("trans_x"); bool transpose_y = ctx.Attr("trans_y"); - auto x = *ctx.Input("X"); - auto y = *ctx.Input("Y"); - auto dout = *ctx.Input(framework::GradVarName("Out")); - - framework::Tensor y_conj(y.type()); - framework::Tensor x_conj(y.type()); - - // get dims - std::vector x_dims = vectorize(x.dims()); - std::vector y_dims = vectorize(y.dims()); - std::vector dout_dims = vectorize(dout.dims()); - - int x_ndim = x_dims.size(); - int y_ndim = y_dims.size(); - int ndim = dout_dims.size(); + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* dout = ctx.Input(framework::GradVarName("Out")); auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); - // Case1 : x's or y's dim = 1 - if (x_ndim == 1 && y_ndim == 1) { - if (dx) dx->mutable_data(ctx.GetPlace()); - if (dy) dy->mutable_data(ctx.GetPlace()); - if (dout.numel() == 1) { - DotGradFunction()(&x, &y, &dout, dx, dy, ctx); - return; - } - } - - bool is_broadcast = true; - if (x_ndim <= 2 || y_ndim <= 2) { - is_broadcast = false; - } else if (x_ndim != y_ndim) { - is_broadcast = true; - } else { - is_broadcast = !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, - y_dims.cbegin()); - } - - // Case2: no broadcast or no batch size, it aims to speed and it is same as - // matmul in old version. - if (!is_broadcast) { - ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); - framework::DDim dx_dims; - if (dx) { - dx_dims = dx->dims(); - if (dx_dims != x.dims()) { - dx->Resize(x.dims()); - } - - // for complex - ConjHelper conj_helper(ctx); - conj_helper(y, y_conj); - } - - framework::DDim dy_dims; - if (dy) { - dy_dims = dy->dims(); - if (dy_dims != y.dims()) { - dy->Resize(y.dims()); - } - - // for complex - ConjHelper conj_helper(ctx); - conj_helper(x, x_conj); - } - if (transpose_x && transpose_y) { - CalcInputGrad(ctx, y_conj, true, true, dout, true, false, dx); - CalcInputGrad(ctx, dout, true, true, x_conj, true, false, dy); - } else if (transpose_x) { - CalcInputGrad(ctx, y_conj, false, false, dout, true, false, dx); - CalcInputGrad(ctx, x_conj, false, false, dout, false, true, dy); - } else if (transpose_y) { - CalcInputGrad(ctx, dout, false, false, y_conj, false, true, dx); - CalcInputGrad(ctx, dout, true, true, x_conj, false, true, dy); - } else { - CalcInputGrad(ctx, dout, false, false, y_conj, true, false, dx); - CalcInputGrad(ctx, x_conj, true, true, dout, false, true, dy); - } - - if (dx) { - if (dx_dims != x.dims()) { - dx->Resize(dx_dims); - } - } - if (dy) { - if (dy_dims != y.dims()) { - dy->Resize(dy_dims); - } - } - } else { - // Case3: broadcast. It need cost much time to reduce sum for the - // broadcast and wastes the memory. - // So we should avoid the case in reality. 
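// NOTE(editor): In the removed Case3 branch below, broadcasting was handled by first
// computing fully-batched gradient helpers (dx_help / dy_help) and then reduce-summing
// every batch axis where the helper has size > 1 but the corresponding input was
// broadcast (size 1). The new Compute simply forwards to pten::MatmulGradKernel, which
// now owns that logic. The standalone sketch below only restates the removed
// axis-selection rule for reference; the function name is illustrative, not an API
// added by this patch.
#include <cstdint>
#include <vector>

// Returns the batch axes that must be reduce-summed for an input whose broadcast
// shape is `input_bcast_dims` when the raw gradient helper has shape `help_dims`.
// The last two axes (the M, N of the matmul) are never reduced.
static std::vector<int> BroadcastReduceDims(
    const std::vector<std::int64_t>& help_dims,
    const std::vector<std::int64_t>& input_bcast_dims) {
  std::vector<int> reduce_dims;
  const int ndim = static_cast<int>(help_dims.size());
  for (int idx = 0; idx <= ndim - 3; ++idx) {
    if (help_dims[idx] != 1 && input_bcast_dims[idx] == 1) {
      reduce_dims.push_back(idx);
    }
  }
  return reduce_dims;
}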
- VLOG(3) << "It need cost much time to reduce sum for the broadcast and " - "wastes the memory. So we should avoid the case in reality"; - Tensor dx_help, dy_help; - - ConjHelper conj_helper(ctx); - conj_helper(x, x_conj); - conj_helper(y, y_conj); - if (transpose_x) { - if (transpose_y) { - // X'Y': dA = Y'G', dB = G'X' - if (dx) - MatMulFunction(&y_conj, &dout, y_dims, dout_dims, - &dx_help, true, true, ctx); - if (dy) - MatMulFunction(&dout, &x_conj, dout_dims, x_dims, - &dy_help, true, true, ctx); - } else { - // X'Y: dX = YG', dY = XG - if (dx) - MatMulFunction(&y_conj, &dout, y_dims, dout_dims, - &dx_help, false, true, ctx); - if (dy) - MatMulFunction(&x_conj, &dout, x_dims, dout_dims, - &dy_help, false, false, ctx); - } - } else { - if (transpose_y) { - // XY': dX = GY, dY = G'X - if (dx) - MatMulFunction(&dout, &y_conj, dout_dims, y_dims, - &dx_help, false, false, ctx); - if (dy) - MatMulFunction(&dout, &x_conj, dout_dims, x_dims, - &dy_help, true, false, ctx); - } else { - // XY: dX = GY', dY = X'G - if (dx) - MatMulFunction(&dout, &y_conj, dout_dims, y_dims, - &dx_help, false, true, ctx); - if (dy) - MatMulFunction(&x_conj, &dout, x_dims, dout_dims, - &dy_help, true, false, ctx); - } - } - - // get help dims - const std::vector dx_help_dims = vectorize(dx_help.dims()); - const std::vector dy_help_dims = vectorize(dy_help.dims()); - - std::vector dx_broadcast_dims(ndim); - std::vector dy_broadcast_dims(ndim); + if (dx) dx->mutable_data(ctx.GetPlace()); + if (dy) dy->mutable_data(ctx.GetPlace()); - std::fill(dx_broadcast_dims.data(), - dx_broadcast_dims.data() + ndim - x_ndim, 1); - std::fill(dy_broadcast_dims.data(), - dy_broadcast_dims.data() + ndim - y_ndim, 1); - std::copy(x_dims.data(), x_dims.data() + x_ndim, - dx_broadcast_dims.data() + ndim - x_ndim); - std::copy(y_dims.data(), y_dims.data() + y_ndim, - dy_broadcast_dims.data() + ndim - y_ndim); + auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); + auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); + auto pt_dout = paddle::experimental::MakePtenDenseTensor(*dout); + auto pt_dx = dx ? paddle::experimental::MakePtenDenseTensor(*dx) + : std::unique_ptr(nullptr); + auto pt_dy = dy ? 
paddle::experimental::MakePtenDenseTensor(*dy) + : std::unique_ptr(nullptr); - std::vector dx_reduce_dims; - std::vector dy_reduce_dims; - for (int idx = 0; idx <= ndim - 3; idx++) { - if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { - dx_reduce_dims.push_back(idx); - } - if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { - dy_reduce_dims.push_back(idx); - } - } - // reduce sum to get grad by ReduceSum - if (dx) { - if (dx_reduce_dims.empty()) { - *dx = std::move(dx_help); - } else { - ReduceSumForMatmulGrad(&dx_help, dx, dx_reduce_dims, - ctx); - } - dx->Resize(x.dims()); - } - if (dy) { - if (dy_reduce_dims.empty()) { - *dy = std::move(dy_help); - } else { - ReduceSumForMatmulGrad(&dy_help, dy, dy_reduce_dims, - ctx); - } - dy->Resize(y.dims()); - } + auto& dev_ctx = ctx.device_context(); - // Get the OutputGrad(out) - } + // call new kernel + pten::MatmulGradKernel(dev_ctx, *pt_x, *pt_y, *pt_dout, transpose_x, + transpose_y, pt_dx.get(), pt_dy.get()); } }; template class MatMulV2DoubleGradKernel : public framework::OpKernel { public: - void MatMul(const framework::ExecutionContext& context, - const framework::Tensor& a, bool trans_a, - const framework::Tensor& b, bool trans_b, framework::Tensor* out, - bool flag) const { - out->mutable_data(context.GetPlace()); - auto blas = math::GetBlas(context); - auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); - auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); - if (a.dims().size() == 3 && b.dims().size() <= 2) { - // the transpose_X must be false, if is true, the transpose cost much time - if (!trans_a) { - mat_dim_a.height_ *= mat_dim_a.batch_size_; - mat_dim_a.batch_size_ = 0; - } - } - blas.MatMul(a, mat_dim_a, b, mat_dim_b, static_cast(1), out, - static_cast(flag)); - } - - void CalcInputGrad(const framework::ExecutionContext& context, - const framework::Tensor& a, bool trans_a, - bool is_fold_init_dims_a, const framework::Tensor& b, - bool trans_b, bool is_fold_init_dims_b, - framework::Tensor* out, bool flag) const { - if (out == nullptr) return; - bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) && - out->dims().size() == 2; - if (!need_combine) { - MatMul(context, a, trans_a, b, trans_b, out, flag); - } else { - auto& ctx = context.template device_context(); - MatMul(context, is_fold_init_dims_a - ? FoldInitDims(a) - : FoldHeadAndLastDims(ctx, a), - trans_a, is_fold_init_dims_b - ? 
FoldInitDims(b) - : FoldHeadAndLastDims(ctx, b), - trans_b, out, flag); - } - } - void Compute(const framework::ExecutionContext& context) const override { - auto x = *context.Input("X"); - auto y = *context.Input("Y"); - auto dout = *context.Input("DOut"); + auto* x = context.Input("X"); + auto* y = context.Input("Y"); + auto* dout = context.Input("DOut"); auto* ddx = context.Input("DDX"); auto* ddy = context.Input("DDY"); @@ -1074,263 +185,84 @@ class MatMulV2DoubleGradKernel : public framework::OpKernel { bool transpose_x = context.Attr("trans_x"); bool transpose_y = context.Attr("trans_y"); - // Get dims from the input x, y, output_grad - std::vector x_dims = vectorize(x.dims()); - std::vector y_dims = vectorize(y.dims()); - std::vector dout_dims = vectorize(dout.dims()); - framework::Tensor x_conj(x.type()); - framework::Tensor y_conj(y.type()); - framework::Tensor dout_conj(dout.type()); - - int x_ndim = x_dims.size(); - int y_ndim = y_dims.size(); - int ndim = dout_dims.size(); - - // Case1 : x's or y's dim = 1 - if (x_ndim == 1 && y_ndim == 1) { - DotDoubleGradFunction()(&x, &y, dx, dy, &dout, ddx, ddy, - ddout, context); - return; - } - - bool is_broadcast = true; - if (x_ndim <= 2 || y_ndim <= 2) { - is_broadcast = false; - } else if (x_ndim != y_ndim) { - is_broadcast = true; - } else { - is_broadcast = !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, - y_dims.cbegin()); - } - - if (!is_broadcast) { - // Case2: no broadcast or no batch size - ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); - framework::DDim dx_dims; - - ConjHelper conj_helper(context); - if (dx) { - dx_dims = dx->dims(); - if (dx_dims != x.dims()) { - dx->Resize(x.dims()); - } - } - - framework::DDim dy_dims; - if (dy) { - dy_dims = dy->dims(); - if (dy_dims != y.dims()) { - dy->Resize(y.dims()); - } - } - - framework::DDim ddout_dims; - if (ddout) { - ddout_dims = ddout->dims(); - if (ddout_dims != dout.dims()) { - ddout->Resize(dout.dims()); - } - } - - if (ddx || ddy) { - ConjHelper conj_helper(context); - conj_helper(dout, dout_conj); - } - if (ddout) { - ConjHelper conj_helper(context); - conj_helper(x, x_conj); - conj_helper(y, y_conj); - } - bool ddout_flag = false; - if (ddx) { - auto ddx_mat = *ddx; - if (ddx_mat.dims() != x.dims()) { - ddx_mat.Resize(x.dims()); - } - if (dy) { - if (transpose_x && transpose_y) { - // dy = dout' * ddx' - CalcInputGrad(context, dout_conj, true, true, ddx_mat, true, false, - dy, false); - } else if (transpose_x) { - // dy = ddx * dout - CalcInputGrad(context, ddx_mat, false, false, dout_conj, false, - true, dy, false); - } else if (transpose_y) { - // dy = dout' * ddx - CalcInputGrad(context, dout_conj, true, true, ddx_mat, false, true, - dy, false); - } else { - // dy = ddx' * dout - CalcInputGrad(context, ddx_mat, true, true, dout_conj, false, true, - dy, false); - } - } - - if (ddout) { - CalcInputGrad(context, ddx_mat, transpose_x, true, y_conj, - transpose_y, false, ddout, ddout_flag); - ddout_flag = true; - } - } - - if (ddy) { - auto ddy_mat = *ddy; - if (ddy_mat.dims() != y.dims()) { - ddy_mat.Resize(y.dims()); - } - if (dx) { - if (transpose_x && transpose_y) { - // dx = ddy' * dout' - CalcInputGrad(context, ddy_mat, true, true, dout_conj, true, false, - dx, false); - } else if (transpose_x) { - // dx = ddy * dout' - CalcInputGrad(context, ddy_mat, false, false, dout_conj, true, - false, dx, false); - } else if (transpose_y) { - // dx = dout * ddy - CalcInputGrad(context, dout_conj, false, false, ddy_mat, false, - true, dx, 
false); - } else { - // dx = dout * ddy' - CalcInputGrad(context, dout_conj, false, false, ddy_mat, true, - false, dx, false); - } - } - - if (ddout) { - CalcInputGrad(context, x_conj, transpose_x, true, ddy_mat, - transpose_y, false, ddout, ddout_flag); - } - } - - if (dx) { - if (dx_dims != x.dims()) { - dx->Resize(dx_dims); - } - } - - if (dy) { - if (dy_dims != y.dims()) { - dy->Resize(dy_dims); - } - } - - if (ddout) { - if (ddout_dims != dout.dims()) { - ddout->Resize(ddout_dims); - } - } - } else { - // Case3: broadcast. It need cost much time to reduce sum for the - // broadcast and wastes the memory. - // So we should avoid the case in reality. - VLOG(3) << "It need cost much time to reduce sum for the broadcast and " - "wastes the memory. So we should avoid the case in reality"; - framework::Tensor ddy_conj(ddx->type()); - framework::Tensor ddx_conj(ddy->type()); + if (dx) dx->mutable_data(context.GetPlace()); + if (dy) dy->mutable_data(context.GetPlace()); + if (ddout) ddout->mutable_data(context.GetPlace()); + + auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); + auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); + auto pt_dout = paddle::experimental::MakePtenDenseTensor(*dout); + auto pt_ddx = paddle::experimental::MakePtenDenseTensor(*ddx); + auto pt_ddy = paddle::experimental::MakePtenDenseTensor(*ddy); + auto pt_dx = paddle::experimental::MakePtenDenseTensor(*dx); + auto pt_dy = paddle::experimental::MakePtenDenseTensor(*dy); + auto pt_ddout = paddle::experimental::MakePtenDenseTensor(*ddout); + + auto& dev_ctx = context.device_context(); + + // call new kernel + pten::MatmulDoubleGradKernel(dev_ctx, *pt_x, *pt_y, *pt_dout, *pt_ddx, + *pt_ddy, transpose_x, transpose_y, + pt_dx.get(), pt_dy.get(), pt_ddout.get()); + } +}; - Tensor dx_help, dy_help; - if (dx || dy) { - ConjHelper conj_helper(context); - conj_helper(dout, dout_conj); - } - if (ddout) { - ConjHelper conj_helper(context); - conj_helper(x, x_conj); - conj_helper(y, y_conj); - } - if (transpose_x) { - if (transpose_y) { - if (dx) - MatMulFunction(ddy, &dout_conj, y_dims, dout_dims, - &dx_help, true, true, context); - if (dy) - MatMulFunction(&dout_conj, ddx, dout_dims, x_dims, - &dy_help, true, true, context); - } else { - if (dx) - MatMulFunction(ddy, &dout_conj, y_dims, dout_dims, - &dx_help, false, true, context); - if (dy) - MatMulFunction(ddx, &dout_conj, x_dims, dout_dims, - &dy_help, false, false, context); - } - } else { - if (transpose_y) { - if (dx) - MatMulFunction(&dout_conj, ddy, dout_dims, y_dims, - &dx_help, false, false, context); - if (dy) - MatMulFunction(&dout_conj, ddx, dout_dims, x_dims, - &dy_help, true, false, context); - } else { - if (dx) - MatMulFunction(&dout_conj, ddy, dout_dims, y_dims, - &dx_help, false, true, context); - if (dy) - MatMulFunction(ddx, &dout_conj, x_dims, dout_dims, - &dy_help, true, false, context); - } - } +template +class MatMulV2TripleGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + // get input + auto* x = context.Input("X"); + auto* y = context.Input("Y"); + auto* dout = context.Input("DOut"); + auto* ddx = context.Input("DDX"); + auto* ddy = context.Input("DDY"); - // get help dims - const std::vector dx_help_dims = vectorize(dx_help.dims()); - const std::vector dy_help_dims = vectorize(dy_help.dims()); + auto* d_dx = context.Input("D_DX"); + auto* d_dy = context.Input("D_DY"); + auto* d_ddout = context.Input("D_DDOut"); - std::vector dx_broadcast_dims(ndim); - 
std::vector dy_broadcast_dims(ndim); + // get output + auto* out_d_x = context.Output("D_X_out"); + auto* out_d_y = context.Output("D_Y_out"); + auto* out_d_dout = context.Output("D_DOut_out"); - std::fill(dx_broadcast_dims.data(), - dx_broadcast_dims.data() + ndim - x_ndim, 1); - std::fill(dy_broadcast_dims.data(), - dy_broadcast_dims.data() + ndim - y_ndim, 1); - std::copy(x_dims.data(), x_dims.data() + x_ndim, - dx_broadcast_dims.data() + ndim - x_ndim); - std::copy(y_dims.data(), y_dims.data() + y_ndim, - dy_broadcast_dims.data() + ndim - y_ndim); + auto* out_d_ddx = context.Output("D_DDX_out"); + auto* out_d_ddy = context.Output("D_DDY_out"); - std::vector dx_reduce_dims; - std::vector dy_reduce_dims; - for (int idx = 0; idx <= ndim - 3; idx++) { - if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { - dx_reduce_dims.push_back(idx); - } - if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { - dy_reduce_dims.push_back(idx); - } - } - // Reduce sum to get grad by ReduceSum - if (dx) { - if (dx_reduce_dims.empty()) { - *dx = std::move(dx_help); - } else { - ReduceSumForMatmulGrad(&dx_help, dx, dx_reduce_dims, - context); - } - dx->Resize(x.dims()); - } - if (dy) { - if (dy_reduce_dims.empty()) { - *dy = std::move(dy_help); - } else { - ReduceSumForMatmulGrad(&dy_help, dy, dy_reduce_dims, - context); - } - dy->Resize(y.dims()); - } + bool transpose_x = context.Attr("trans_x"); + bool transpose_y = context.Attr("trans_y"); - if (ddout) { - // Caluate the gradient of OutputGrad(Out) - MatMulFunction(ddx, &y_conj, x_dims, y_dims, ddout, - transpose_x, transpose_y, context); - MatMulFunction(&x_conj, ddy, x_dims, y_dims, ddout, - transpose_x, transpose_y, context, - true); - } - } + if (out_d_x) out_d_x->mutable_data(context.GetPlace()); + if (out_d_y) out_d_y->mutable_data(context.GetPlace()); + if (out_d_dout) out_d_dout->mutable_data(context.GetPlace()); + if (out_d_ddx) out_d_ddx->mutable_data(context.GetPlace()); + if (out_d_ddy) out_d_ddy->mutable_data(context.GetPlace()); + + auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); + auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); + auto pt_dout = paddle::experimental::MakePtenDenseTensor(*dout); + auto pt_ddx = paddle::experimental::MakePtenDenseTensor(*ddx); + auto pt_ddy = paddle::experimental::MakePtenDenseTensor(*ddy); + auto pt_d_dx = paddle::experimental::MakePtenDenseTensor(*d_dx); + auto pt_d_dy = paddle::experimental::MakePtenDenseTensor(*d_dy); + auto pt_d_ddout = paddle::experimental::MakePtenDenseTensor(*d_ddout); + + auto pt_out_d_x = paddle::experimental::MakePtenDenseTensor(*out_d_x); + auto pt_out_d_y = paddle::experimental::MakePtenDenseTensor(*out_d_y); + auto pt_out_d_dout = paddle::experimental::MakePtenDenseTensor(*out_d_dout); + auto pt_out_d_ddx = paddle::experimental::MakePtenDenseTensor(*out_d_ddx); + auto pt_out_d_ddy = paddle::experimental::MakePtenDenseTensor(*out_d_ddy); + + auto& dev_ctx = context.device_context(); + // call new kernel + pten::MatmulTripleGradKernel(dev_ctx, *pt_x, *pt_y, *pt_dout, *pt_ddx, + *pt_ddy, *pt_d_dx, *pt_d_dy, *pt_d_ddout, + transpose_x, transpose_y, pt_out_d_x.get(), + pt_out_d_y.get(), pt_out_d_dout.get(), + pt_out_d_ddx.get(), pt_out_d_ddy.get()); } }; + } // namespace operators } // namespace paddle diff --git a/paddle/pten/core/dense_tensor.cc b/paddle/pten/core/dense_tensor.cc index 1b4254ad2c1038527d78dea6de106bb098abbc55..0b5f5cb18e13dbd79166e5f1a96608eb9a9411dc 100644 --- a/paddle/pten/core/dense_tensor.cc +++ 
b/paddle/pten/core/dense_tensor.cc @@ -70,6 +70,12 @@ DenseTensor& DenseTensor::operator=(const DenseTensor& other) { return *this; } +DenseTensor& DenseTensor::operator=(DenseTensor&& other) { + meta_ = std::move(other.meta_); + storage_.swap(other.storage_); + return *this; +} + int64_t DenseTensor::numel() const { if (meta_.is_scalar) { return 1; diff --git a/paddle/pten/core/dense_tensor.h b/paddle/pten/core/dense_tensor.h index fc92e84f52cea7148aab3f2e35855586bb4ab361..1502accd197be6fddd1d0849e7373bebea7adf8b 100644 --- a/paddle/pten/core/dense_tensor.h +++ b/paddle/pten/core/dense_tensor.h @@ -97,6 +97,8 @@ class DenseTensor : public TensorBase, /// \brief DenseTensor shallow copy assignment. DenseTensor& operator=(const DenseTensor& other); + DenseTensor& operator=(DenseTensor&& other); + /// \brief Destroy the tensor object and release exclusive resources. virtual ~DenseTensor() = default; diff --git a/paddle/pten/core/kernel_alias_name.h b/paddle/pten/core/kernel_alias_name.h index 56f7eea7ea802dd94d4c5aecf82732dae27d3b8b..46fa6dd376ee385d695b8674ecae2064d1df0a08 100644 --- a/paddle/pten/core/kernel_alias_name.h +++ b/paddle/pten/core/kernel_alias_name.h @@ -29,6 +29,9 @@ const std::unordered_map kernel_alias_name_map = { {"flatten_contiguous_range", "flatten"}, {"flatten_contiguous_range_grad", "flatten_grad"}, {"matmul_v2", "matmul"}, + {"matmul_v2_grad", "matmul_grad"}, + {"matmul_v2_grad_grad", "matmul_double_grad"}, + {"matmul_v2_triple_grad", "matmul_triple_grad"}, {"reduce_mean", "mean"}, {"reduce_sum", "sum"}, {"reshape2", "reshape"}, @@ -36,6 +39,8 @@ const std::unordered_map kernel_alias_name_map = { {"flatten", "deprecated"}, {"flatten_grad", "deprecated"}, {"matmul", "deprecated"}, + {"matmul_grad", "deprecated"}, + {"matmul_grad_grad", "deprecated"}, {"mean", "deprecated"}, {"reshape", "deprecated"}, {"sum", "deprecated"}}; diff --git a/paddle/pten/core/kernel_context.cc b/paddle/pten/core/kernel_context.cc index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..74bd6d17f066a2b569907cd933a558e505d6cf87 100644 --- a/paddle/pten/core/kernel_context.cc +++ b/paddle/pten/core/kernel_context.cc @@ -0,0 +1,134 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
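// NOTE(editor): KernelContext (declared in kernel_context.h, implemented below) keeps
// all inputs and outputs in flat SmallVectors and records a [start, end) index pair for
// each argument slot, so a single slot can also map to a list of tensors. The sketch
// below shows how bridging code could populate a context for matmul; the helper name is
// illustrative only, while the member functions are the ones defined in this file.

#include <memory>   // for the sketch below
#include <utility>

#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_context.h"

static void BuildMatmulContextSketch(pten::KernelContext* ctx,
                                     std::shared_ptr<pten::DenseTensor> x,
                                     std::shared_ptr<pten::DenseTensor> y,
                                     std::shared_ptr<pten::DenseTensor> out,
                                     bool trans_x,
                                     bool trans_y) {
  ctx->EmplaceBackInput(std::move(x));     // input slot 0 -> range [0, 1)
  ctx->EmplaceBackInput(std::move(y));     // input slot 1 -> range [1, 2)
  ctx->EmplaceBackAttr(trans_x);           // attributes are stored as paddle::any
  ctx->EmplaceBackAttr(trans_y);
  ctx->EmplaceBackOutput(std::move(out));  // output slot 0 -> range [0, 1)
}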
+ +#include "paddle/pten/core/kernel_context.h" + +namespace pten { + +void KernelContext::EmplaceBackInput(std::shared_ptr input) { + int index = inputs_.size(); + inputs_.emplace_back(std::move(input)); + // Record the start and end index of the input + input_range_.emplace_back(std::pair(index, index + 1)); +} + +void KernelContext::EmplaceBackInputWithoutSetRange( + std::shared_ptr input) { + inputs_.emplace_back(std::move(input)); +} + +void KernelContext::EmplaceBackInputs( + paddle::SmallVector> inputs) { + int index = inputs_.size(); + // Record the start and end index of the input + input_range_.emplace_back(std::pair(index, index + inputs.size())); + inputs_.insert(inputs_.end(), + std::make_move_iterator(inputs.begin()), + std::make_move_iterator(inputs.end())); +} + +void KernelContext::EmplaceBackOutput(std::shared_ptr output) { + int index = outputs_.size(); + outputs_.emplace_back(std::move(output)); + // Record the start and end index of the input + output_range_.emplace_back(std::pair(index, index + 1)); +} + +void KernelContext::EmplaceBackOutputWithoutSetRange( + std::shared_ptr output) { + outputs_.emplace_back(std::move(output)); +} + +void KernelContext::SetOutputWithoutSetRange( + int index, std::shared_ptr output) { + outputs_.at(index) = std::move(output); +} + +void KernelContext::EmplaceBackOutputs( + paddle::SmallVector> outputs) { + int index = outputs_.size(); + // Record the start and end index of the input + output_range_.emplace_back( + std::pair(index, index + outputs.size())); + outputs_.insert(outputs_.end(), + std::make_move_iterator(outputs.begin()), + std::make_move_iterator(outputs.end())); +} + +void KernelContext::EmplaceBackAttr(paddle::any attr) { + attrs_.emplace_back(std::move(attr)); +} + +void KernelContext::AssignInputRange(std::pair&& range, size_t idx) { + if (idx < input_range_.size()) { + input_range_[idx] = range; + } else if (idx == input_range_.size()) { + input_range_.emplace_back(range); + } else { + PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + "Invalid idx when trying to set InputRange, " + "index is `%d`, it is greater than the size(%d) of InputRange.", + idx, + input_range_.size())); + } +} + +void KernelContext::AssignOutputRange(std::pair&& range, size_t idx) { + if (idx < output_range_.size()) { + output_range_[idx] = range; + } else if (idx == output_range_.size()) { + output_range_.emplace_back(range); + } else { + PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + "Invalid idx when trying to set InputRange, " + "index is `%d`, it is greater than the size(%d) of InputRange.", + idx, + output_range_.size())); + } +} + +const std::pair& KernelContext::InputRangeAt(size_t idx) const { + return input_range_.at(idx); +} + +const std::pair& KernelContext::OutputRangeAt(size_t idx) const { + return output_range_.at(idx); +} + +std::pair& KernelContext::MutableInputRangeAt(size_t idx) { + return input_range_[idx]; +} + +std::pair& KernelContext::MutableOutputRangeAt(size_t idx) { + return output_range_[idx]; +} + +// Temporary method: For compatible with fluid Tensor and improve performance +// Only deal with DenseTensor now +void KernelContext::ClearData() { + for (auto& in : inputs_) { + if (in) { + CompatibleDenseTensorUtils::ClearStorage( + static_cast(in.get())); + } + } + for (auto& out : outputs_) { + if (out) { + CompatibleDenseTensorUtils::ClearStorage( + static_cast(out.get())); + } + } + attrs_.clear(); +} +} // namespace pten diff --git a/paddle/pten/core/kernel_context.h 
b/paddle/pten/core/kernel_context.h index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..b6cc15c084ac0294add16ef8466a8259de1ab9f4 100644 --- a/paddle/pten/core/kernel_context.h +++ b/paddle/pten/core/kernel_context.h @@ -0,0 +1,165 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/pten/core/compat_utils.h" +#include "paddle/pten/core/tensor_base.h" +#include "paddle/utils/any.h" +#include "paddle/utils/small_vector.h" + +// See Note [ Why still include the fluid headers? ] +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" + +namespace pten { + +using DeviceContext = paddle::platform::DeviceContext; +using DataType = paddle::experimental::DataType; +using DataLayout = paddle::experimental::DataLayout; + +/** + * Note: KernelContext doesn't manage the life if DeviceContext and Tensor + * + * Note: KernelContext does not couple the concept of framework, + * its constructor can only take the members it needs as parameters, + * not Scope, RuntimeContext, etc. as parameters + */ +class KernelContext { + public: + KernelContext() = default; + explicit KernelContext(DeviceContext* dev_ctx) : dev_ctx_(dev_ctx) {} + + void SetDeviceContext(DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; } + + template + const CtxType& GetDeviceContext() const { + return static_cast(*dev_ctx_); + } + + void EmplaceBackInput(std::shared_ptr input); + + void EmplaceBackInputWithoutSetRange(std::shared_ptr input); + + void EmplaceBackInputs( + paddle::SmallVector> inputs); + + void EmplaceBackOutput(std::shared_ptr output); + + void EmplaceBackOutputWithoutSetRange(std::shared_ptr output); + + void SetOutputWithoutSetRange(int index, std::shared_ptr output); + + void EmplaceBackOutputs( + paddle::SmallVector> outputs); + + void EmplaceBackAttr(paddle::any attr); + + const std::pair& InputRangeAt(size_t idx) const; + + const std::pair& OutputRangeAt(size_t idx) const; + + std::pair& MutableInputRangeAt(size_t idx); + + std::pair& MutableOutputRangeAt(size_t idx); + + template + const TensorType& InputAt(size_t idx) const { + return static_cast(*(inputs_.at(idx))); + } + + template + paddle::optional OptionalInputAt(size_t idx) const { + const auto& input = inputs_.at(idx); + return input ? 
paddle::optional{static_cast< + const TensorType&>(*input)} + : paddle::optional{paddle::none}; + } + + std::shared_ptr& MutableInputPtrAt(size_t idx) { + return inputs_.at(idx); + } + + template + std::vector MoveInputsBetween(size_t start, size_t end) { + std::vector v; + for (size_t i = start; i < end; ++i) { + auto t = std::dynamic_pointer_cast(inputs_.at(i)); + v.emplace_back(std::move(*t.get())); + inputs_.at(i) = nullptr; + } + return v; + } + + void AssignInputRange(std::pair&& range, size_t idx); + + void AssignOutputRange(std::pair&& range, size_t idx); + + template + TensorType* MutableInputAt(size_t idx) { + return static_cast(inputs_.at(idx).get()); + } + + template + TensorType* MutableOutputAt(size_t idx) { + return static_cast(outputs_.at(idx).get()); + } + + template + std::vector MutableOutputBetween(size_t start, size_t end) { + std::vector v; + for (size_t i = start; i < end; ++i) { + v.emplace_back(static_cast(outputs_.at(i).get())); + } + + return v; + } + + template + AttrType AttrAt(size_t idx) const { + try { + return paddle::any_cast(attrs_.at(idx)); + } catch (paddle::bad_any_cast&) { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "Attribute cast error in Op Kernel Context.")); + } + } + + // Temporary method: For compatible with fluid Tensor and improve performance + // Only deal with DenseTensor now + void ClearData(); + + size_t InputsSize() const { return inputs_.size(); } + size_t OutputsSize() const { return outputs_.size(); } + size_t AttrsSize() const { return attrs_.size(); } + + private: + // DeviceContext base class + DeviceContext* dev_ctx_; + + // TODO(chenweihang): Tensor -> Tensor*, Tensor should by managed `scope` + // Note: can't use API Tensor here, the inference don't use this API Tensor + paddle::SmallVector> inputs_; + paddle::SmallVector> outputs_; + paddle::SmallVector attrs_; + + // Only contains input like list[Tensor] need `range` + paddle::SmallVector> input_range_; + paddle::SmallVector> output_range_; +}; + +} // namespace pten diff --git a/paddle/pten/core/kernel_registry.h b/paddle/pten/core/kernel_registry.h index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..f08ef4acfd9ce713b53c71006ebafdee114112f0 100644 --- a/paddle/pten/core/kernel_registry.h +++ b/paddle/pten/core/kernel_registry.h @@ -0,0 +1,1398 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
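// NOTE(editor): this header provides the registration machinery: each PT_REGISTER_*
// macro instantiates the kernel template for the listed dtypes, defines a static
// KernelRegistrar per dtype that parses the kernel signature into a KernelArgsDef, and
// inserts the resulting Kernel into KernelFactory under (name, backend, layout, dtype).
// A hedged usage sketch follows; the dtype list is illustrative and the exact call site
// is an assumption, not part of this diff:
//
//   // e.g. in a pten kernel source such as paddle/pten/kernels/cpu/matmul_kernel.cc
//   PT_REGISTER_CTX_KERNEL(matmul, CPU, ALL_LAYOUT, pten::MatmulKernel,
//                          float, double) {}
//
// The trailing braces are the per-kernel args_def_fn body that the macro declares and
// lets the registration site fill in (empty when the parsed defaults are enough).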
+ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "paddle/pten/core/kernel_def.h" +#include "paddle/pten/core/kernel_factory.h" +#include "paddle/pten/core/kernel_utils.h" + +#include "paddle/fluid/platform/enforce.h" + +namespace pten { + +#define BACKEND(arg__) pten::Backend::arg__ +#define DATALAYOUT(arg__) pten::DataLayout::arg__ +#define DATATYPE(arg__) pten::DataType::arg__ + +template +struct KernelArgsParseFunctor; + +template +struct KernelArgsParseFunctor { + using Args = std::tuple; + enum : std::size_t { Arity = sizeof...(Args_) }; + using Indices = std::make_index_sequence; + template + using Arg = typename std::tuple_element::type; + + static void Parse(const KernelKey& default_key, KernelArgsDef* args_def) { + // TODO(chenweihang): The fluid Tensor's default layout is NCHW, + // it is not same as kernel's layout, we should fix this error on + // fluid Tensor + auto default_tensor_layout = pten::DataLayout::NCHW; + if (default_key.layout() != pten::DataLayout::ANY) { + default_tensor_layout = default_key.layout(); + } + auto args_type = ParseArgType(Indices{}); + for (auto arg_type : args_type) { + if (arg_type == std::type_index(typeid(const CPUContext&)) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + || + arg_type == std::type_index(typeid(const GPUContext&))) { +#else + ) { +#endif + // do nothing, skip context arg now + } else if (arg_type == std::type_index(typeid(const DenseTensor&))) { + args_def->AppendInput( + default_key.backend(), default_tensor_layout, default_key.dtype()); + } else if (arg_type == std::type_index(typeid( + paddle::optional))) { + args_def->AppendInput( + default_key.backend(), default_tensor_layout, default_key.dtype()); + } else if (arg_type == + std::type_index(typeid(const std::vector&))) { + args_def->AppendInput( + default_key.backend(), default_tensor_layout, default_key.dtype()); + } else if (arg_type == std::type_index(typeid(DenseTensor*))) { + args_def->AppendOutput( + default_key.backend(), default_tensor_layout, default_key.dtype()); + } else if (arg_type == + std::type_index(typeid(std::vector))) { + args_def->AppendOutput( + default_key.backend(), default_tensor_layout, default_key.dtype()); + } else { + // Attribute deal with + // TODO(chenweihang): now here allow any types of attribute, maybe + // should add limits here + args_def->AppendAttribute(arg_type); + } + } + } + + private: + template + static std::vector ParseArgType( + std::index_sequence) { + return {std::type_index(typeid(Arg))...}; + } +}; + +// TODO(chenweihang): Polish the kernel selection logic, support the selection +// of ALL_DTYPE kernel, and simplify the constructor +struct KernelRegistrar { + public: + KernelRegistrar(const char* kernel_name_cstr, + Backend backend, + DataLayout layout, + DataType dtype, + KernelArgsParseFn args_parse_fn, + KernelArgsDefFn args_def_fn, + KernelFn kernel_fn, + void* variadic_kernel_fn) { + ConstructKernel(kernel_name_cstr, + backend, + layout, + dtype, + args_parse_fn, + args_def_fn, + kernel_fn, + variadic_kernel_fn); + } + + KernelRegistrar(const char* kernel_name_cstr, + Backend backend, + DataLayout layout, + KernelArgsParseFn args_parse_fn, + KernelArgsDefFn args_def_fn, + KernelFn kernel_fn, + void* variadic_kernel_fn) { + for (size_t dtype = static_cast(DataType::BOOL); + dtype != static_cast(DataType::NUM_DATA_TYPES); + dtype++) { + ConstructKernel(kernel_name_cstr, + backend, + layout, + static_cast(dtype), + args_parse_fn, + args_def_fn, + kernel_fn, + 
variadic_kernel_fn); + } + } + + private: + void ConstructKernel(const char* kernel_name_cstr, + Backend backend, + DataLayout layout, + DataType dtype, + KernelArgsParseFn args_parse_fn, + KernelArgsDefFn args_def_fn, + KernelFn kernel_fn, + void* variadic_kernel_fn) { + std::string kernel_name(kernel_name_cstr); + KernelKey kernel_key(backend, layout, dtype); + Kernel kernel(kernel_fn, variadic_kernel_fn); + args_parse_fn(kernel_key, kernel.mutable_args_def()); + args_def_fn(&kernel); + KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel; + } +}; + +#define PT_STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \ + _PT_STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) + +#define _PT_STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \ + struct __test_global_namespace_##uniq_name##__ {}; \ + static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \ + __test_global_namespace_##uniq_name##__>::value, \ + msg) + +#ifdef __COUNTER__ +#define PT_ID __COUNTER__ +#else +#define PT_ID __LINE__ +#endif + +#if defined(_WIN32) +#define UNUSED +#define __builtin_expect(EXP, C) (EXP) +#else +#define UNUSED __attribute__((unused)) +#endif + +#define PT_CONCATENATE(arg1, arg2) PT_CONCATENATE1(arg1, arg2) +#define PT_CONCATENATE1(arg1, arg2) PT_CONCATENATE2(arg1, arg2) +#define PT_CONCATENATE2(arg1, arg2) arg1##arg2 +#define PT_EXPAND(x) x + +/** + * Reference: + * + * https://stackoverflow.com/questions/1872220/is-it-possible-to-iterate-over-arguments-in-variadic-macros + * https://stackoverflow.com/questions/9183993/msvc-variadic-macro-expansion?rq=1 + * https://stackoverflow.com/questions/5134523/msvc-doesnt-expand-va-args-correctly + * + * Very carefully tiptoeing around an MSVC bug where it improperly expands + * __VA_ARGS__ as a single token in argument lists. See these URLs for details: + * + * http://connect.microsoft.com/VisualStudio/feedback/details/380090/variadic-macro-replacement + * http://cplusplus.co.il/2010/07/17/variadic-macro-to-count-number-of-arguments/#comment-644 + */ +#define PT_NARGS(...) _PT_NARGS((__VA_ARGS__, _PT_RESQ_N())) +#define _PT_NARGS(...) _PT_ARG_N(__VA_ARGS__) +#define _PT_ARG_N_EXPAND( \ + _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, ...) \ + N +#define _PT_ARG_N(args) _PT_ARG_N_EXPAND args +#define _PT_RESQ_N() 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 + +/** PT_REGISTER_KERNEL + * + * The most frequently used kernel registration macro, used for kernel + * registration with only data type as template parameter, and the function + * pointer of the corresponding data type is automatically instantiated + * during registration. + * + * Note: `1TA` means `1 template argument` + */ +#define PT_REGISTER_KERNEL( \ + kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \ + "PT_REGISTER_KERNEL must be called in global namespace."); \ + _PT_REGISTER_1TA_KERNEL( \ + kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__) + +#ifndef _WIN32 +#define _PT_REGISTER_1TA_KERNEL( \ + kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) 
\ + PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, __VA_ARGS__); \ + static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel*); \ + PT_KERNEL_REGISTRAR_INIT( \ + kernel_name, \ + backend, \ + layout, \ + &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ + meta_kernel_fn, \ + cpp_dtype, \ + __VA_ARGS__); \ + void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel* kernel) +#else +/** + * `template decltype(fn) fn` can work on gcc and clang, + * but msvc will failed, error like: + * + * error C2206: typedef cannot be used for function definition + * + * reference: + * + * https://stackoverflow.com/questions/63989585/explicit-instantiation-of-function-using-decltype-work-on-g-but-not-on-visua + * + * And msvc can work without template instantiation + */ +#define _PT_REGISTER_1TA_KERNEL( \ + kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ + static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel*); \ + PT_KERNEL_REGISTRAR_INIT( \ + kernel_name, \ + backend, \ + layout, \ + &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ + meta_kernel_fn, \ + cpp_dtype, \ + __VA_ARGS__); \ + void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel* kernel) +#endif + +#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, ...) \ + _PT_KERNEL_INSTANTIATION(PT_NARGS(cpp_dtype, __VA_ARGS__), \ + meta_kernel_fn, \ + cpp_dtype, \ + __VA_ARGS__) + +#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, cpp_dtype, ...) \ + PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \ + (meta_kernel_fn, cpp_dtype, __VA_ARGS__) + +#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn +#define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, cpp_dtype, ...) 
\ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_15(meta_kernel_fn, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, __VA_ARGS__)) + +#define PT_KERNEL_REGISTRAR_INIT( \ + kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \ + _PT_KERNEL_REGISTRAR_INIT(PT_NARGS(cpp_dtype, __VA_ARGS__), \ + kernel_name, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + __VA_ARGS__) + +// clang-format off + +/* The =pre-commit always treats this macro into the wrong format, + and multi-line macros cannot be skipped with NOLINT.*/ +#define _PT_KERNEL_REGISTRAR_INIT(N, \ + kernel_name, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \ + kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + __VA_ARGS__) + +// clang-format on + +#define _PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; } +#define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) 
\ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) 
\ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) 
\
+  static const ::pten::KernelRegistrar PT_CONCATENATE(  \
+      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)(  \
+      #kernel_name,  \
+      BACKEND(backend),  \
+      DATALAYOUT(layout),  \
+      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),  \
+      ::pten::KernelArgsParseFunctor<decltype(&meta_kernel_fn<cpp_dtype>)>::Parse,  \
+      args_def_fn,  \
+      PT_KERNEL(meta_kernel_fn<cpp_dtype>),  \
+      PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>));  \
+  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name,  \
+                                         backend,  \
+                                         layout,  \
+                                         PT_ID,  \
+                                         args_def_fn,  \
+                                         meta_kernel_fn,  \
+                                         __VA_ARGS__))
+#define _PT_KERNEL_REGISTRAR_INIT_14(kernel_name,  \
+                                     backend,  \
+                                     layout,  \
+                                     registrar_id,  \
+                                     args_def_fn,  \
+                                     meta_kernel_fn,  \
+                                     cpp_dtype,  \
+                                     ...)  \
+  static const ::pten::KernelRegistrar PT_CONCATENATE(  \
+      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)(  \
+      #kernel_name,  \
+      BACKEND(backend),  \
+      DATALAYOUT(layout),  \
+      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),  \
+      ::pten::KernelArgsParseFunctor<decltype(&meta_kernel_fn<cpp_dtype>)>::Parse,  \
+      args_def_fn,  \
+      PT_KERNEL(meta_kernel_fn<cpp_dtype>),  \
+      PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>));  \
+  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name,  \
+                                         backend,  \
+                                         layout,  \
+                                         PT_ID,  \
+                                         args_def_fn,  \
+                                         meta_kernel_fn,  \
+                                         __VA_ARGS__))
+#define _PT_KERNEL_REGISTRAR_INIT_15(kernel_name,  \
+                                     backend,  \
+                                     layout,  \
+                                     registrar_id,  \
+                                     args_def_fn,  \
+                                     meta_kernel_fn,  \
+                                     cpp_dtype,  \
+                                     ...)  \
+  static const ::pten::KernelRegistrar PT_CONCATENATE(  \
+      __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)(  \
+      #kernel_name,  \
+      BACKEND(backend),  \
+      DATALAYOUT(layout),  \
+      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(),  \
+      ::pten::KernelArgsParseFunctor<decltype(&meta_kernel_fn<cpp_dtype>)>::Parse,  \
+      args_def_fn,  \
+      PT_KERNEL(meta_kernel_fn<cpp_dtype>),  \
+      PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>));  \
+  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name,  \
+                                         backend,  \
+                                         layout,  \
+                                         PT_ID,  \
+                                         args_def_fn,  \
+                                         meta_kernel_fn,  \
+                                         __VA_ARGS__))
+
+/** PT_REGISTER_NO_TEMPLATE_KERNEL
+ *
+ * Basic kernel registration macro, used to register a kernel function that
+ * takes no template arguments. The complete function pointer of the kernel is
+ * passed in, and this registration macro does not perform automatic template
+ * instantiation.
+ *
+ * Note: developers may register two kernels with the same name and backend
+ * but different layouts, so the layout also needs to be part of the symbol
+ * variable name. If two kernels share the same name, backend and layout but
+ * differ in dtype, the other registration macro, PT_REGISTER_KERNEL, should
+ * be used instead.
+ *
+ * TODO(chenweihang): remove this macro later
+ */
+#define PT_REGISTER_NO_TEMPLATE_KERNEL(  \
+    kernel_name, backend, layout, kernel_fn, dtype)  \
+  PT_STATIC_ASSERT_GLOBAL_NAMESPACE(  \
+      pt_register_no_t_kernel_ns_check_##kernel_name##_##backend##_##layout,  \
+      "PT_REGISTER_NO_TEMPLATE_KERNEL must be called in global namespace.");  \
+  static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(  \
+      ::pten::Kernel*);  \
+  static const ::pten::KernelRegistrar  \
+      __reg_pt_kernel_##kernel_name##_##backend##_##layout(  \
+          #kernel_name,  \
+          BACKEND(backend),  \
+          DATALAYOUT(layout),  \
+          ::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse,  \
+          &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout,  \
+          PT_KERNEL(kernel_fn),  \
+          PT_VARIADIC_KERNEL(kernel_fn));  \
+  int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() {  \
+    return 0;  \
+  }  \
+  void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(  \
+      ::pten::Kernel* kernel)
+
+/** PT_REGISTER_GENERAL_KERNEL
+ *
+ * Basic kernel registration macro, used to register an already instantiated
+ * kernel function with one template argument.
+ */
+
+#define PT_REGISTER_GENERAL_KERNEL(  \
+    kernel_name, backend, layout, kernel_fn, dtype)  \
+  PT_STATIC_ASSERT_GLOBAL_NAMESPACE(  \
+      pt_register_no_t_kernel_ns_check_##kernel_name##_##backend##_##layout,  \
+      "PT_REGISTER_GENERAL_KERNEL must be called in global namespace.");  \
+  _PT_REGISTER_GENERAL_KERNEL(kernel_name, backend, layout, kernel_fn, dtype)
+
+#ifndef _WIN32
+#define _PT_REGISTER_GENERAL_KERNEL(  \
+    kernel_name, backend, layout, kernel_fn, dtype)  \
+  template decltype(kernel_fn) kernel_fn;  \
+  static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(  \
+      ::pten::Kernel*);  \
+  static const ::pten::KernelRegistrar  \
+      __reg_pt_kernel_##kernel_name##_##backend##_##layout(  \
+          #kernel_name,  \
+          BACKEND(backend),  \
+          DATALAYOUT(layout),  \
+          ::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse,  \
+          &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout,  \
+          PT_KERNEL(kernel_fn),  \
+          PT_VARIADIC_KERNEL(kernel_fn));  \
+  int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() {  \
+    return 0;  \
+  }  \
+  void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(  \
+      ::pten::Kernel* kernel)
+#else
+#define _PT_REGISTER_GENERAL_KERNEL(  \
+    kernel_name, backend, layout, kernel_fn, dtype)  \
+  static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(  \
+      ::pten::Kernel*);  \
+  static const ::pten::KernelRegistrar  \
+      __reg_pt_kernel_##kernel_name##_##backend##_##layout(  \
+          #kernel_name,  \
+          BACKEND(backend),  \
+          DATALAYOUT(layout),  \
+          ::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse,  \
+          &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout,  \
+          PT_KERNEL(kernel_fn),  \
+          PT_VARIADIC_KERNEL(kernel_fn));  \
+  int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() {  \
+    return 0;  \
+  }  \
+  void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout(  \
+      ::pten::Kernel* kernel)
+#endif
+
+/** PT_REGISTER_CTX_KERNEL
+ *
+ * Used for kernel registration, with the device context and the data type as
+ * template parameters.
+ */
+#define PT_REGISTER_CTX_KERNEL(  \
+    kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...)
\ + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + pt_register_tp_ctx_kernel_ns_check_##kernel_name##_##backend##_##layout, \ + "PT_REGISTER_CTX_KERNEL must be called in global namespace."); \ + _PT_REGISTER_2TA_KERNEL( \ + kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__) + +#ifndef _WIN32 +#define _PT_REGISTER_2TA_KERNEL( \ + kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ + PT_KERNEL_INSTANTIATION2(meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__); \ + static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel*); \ + PT_KERNEL_REGISTRAR_INIT2( \ + kernel_name, \ + backend, \ + layout, \ + &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ + meta_kernel_fn, \ + cpp_dtype, \ + __VA_ARGS__); \ + void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel* kernel) +#else +#define _PT_REGISTER_2TA_KERNEL( \ + kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ + static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel*); \ + PT_KERNEL_REGISTRAR_INIT2( \ + kernel_name, \ + backend, \ + layout, \ + &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ + meta_kernel_fn, \ + cpp_dtype, \ + __VA_ARGS__); \ + void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ + ::pten::Kernel* kernel) +#endif + +#define PT_KERNEL_INSTANTIATION2(meta_kernel_fn, backend, cpp_dtype, ...) \ + _PT_KERNEL_INSTANTIATION2(PT_NARGS(cpp_dtype, __VA_ARGS__), \ + meta_kernel_fn, \ + backend, \ + cpp_dtype, \ + __VA_ARGS__) + +#define _PT_KERNEL_INSTANTIATION2(N, meta_kernel_fn, backend, cpp_dtype, ...) \ + PT_CONCATENATE(_PT_KERNEL_INSTANTIATION2_, N) \ + (meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__) + +#define _PT_KERNEL_INSTANTIATION2_1(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn +#define _PT_KERNEL_INSTANTIATION2_2(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_1(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_3(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_2(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_4(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_3(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_5(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_4(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_6(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_5(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_7(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_6(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_8(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_7(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_9(meta_kernel_fn, backend, cpp_dtype, ...) 
\ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_8(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_10(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_9(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_11(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_10(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_12(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_11(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_13(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_12(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_14(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_13(meta_kernel_fn, backend, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION2_15(meta_kernel_fn, backend, cpp_dtype, ...) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION2_14(meta_kernel_fn, backend, __VA_ARGS__)) + +#define PT_KERNEL_REGISTRAR_INIT2( \ + kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \ + _PT_KERNEL_REGISTRAR_INIT2(PT_NARGS(cpp_dtype, __VA_ARGS__), \ + kernel_name, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + __VA_ARGS__) + +// clang-format off + +/* The =pre-commit always treats this macro into the wrong format, + and multi-line macros cannot be skipped with NOLINT.*/ +#define _PT_KERNEL_REGISTRAR_INIT2(N, \ + kernel_name, \ + backend, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT2_, N) ( \ + kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + __VA_ARGS__) + +// clang-format on + +#define _PT_KERNEL_REGISTRAR_INIT2_1(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; } +#define _PT_KERNEL_REGISTRAR_INIT2_2(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) 
\ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_1(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_3(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_2(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_4(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_3(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_5(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_4(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_6(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_5(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_7(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) 
\ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_6(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_8(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_7(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_9(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_8(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_10(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_9(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_11(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_10(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_12(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) 
\ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_11(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_13(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_12(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_14(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_13(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) +#define _PT_KERNEL_REGISTRAR_INIT2_15(kernel_name, \ + backend, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + #kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::pten::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_14(kernel_name, \ + backend, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + __VA_ARGS__)) + +/** PT_DECLARE_KERNEL + * + * Used to export the symbols of the file where the kernel is located, + * to avoid being removed by linker + */ +#define PT_DECLARE_KERNEL(kernel_name, backend, layout) \ + extern int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout(); \ + UNUSED static int \ + __declare_kernel_symbol_for_##kernel_name##_##backend##_##layout = \ + TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() + +} // namespace pten diff --git a/paddle/pten/core/kernel_utils.h b/paddle/pten/core/kernel_utils.h index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..60201151c62a23130878d93cc0992f9b6e79c02e 100644 --- a/paddle/pten/core/kernel_utils.h +++ b/paddle/pten/core/kernel_utils.h @@ -0,0 +1,254 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/backends/all_context.h" +#include "paddle/pten/common/scalar.h" +#include "paddle/pten/common/scalar_array.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_context.h" +#include "paddle/pten/core/kernel_def.h" + +// See Note [ Why still include the fluid headers? ] +#include "paddle/fluid/platform/enforce.h" + +namespace pten { + +#define PT_KERNEL(...) \ + ::pten::KernelImpl::Compute + +#define PT_VARIADIC_KERNEL(...) \ + reinterpret_cast(&::pten::KernelImpl::VariadicCompute) + +#define PT_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(dev_ctx) \ + template \ + struct KernelCallHelper { \ + template \ + static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ + static_assert(in_idx == 0, \ + "Kernel's DeviceContext should appear before Inputs."); \ + static_assert( \ + attr_idx == 0, \ + "Kernel's DeviceContext should appear before Attributes."); \ + static_assert(out_idx == 0, \ + "Kernel's DeviceContext should appear before Outputs."); \ + const dev_ctx& arg = ctx->GetDeviceContext(); \ + KernelCallHelper:: \ + template Compute( \ + ctx, pargs..., arg); \ + } \ + } + +#define PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(tensor_type) \ + template \ + struct KernelCallHelper { \ + template \ + static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ + static_assert(attr_idx == 0, \ + "Kernel's Input should appear before Attributes."); \ + static_assert(out_idx == 0, \ + "Kernel's Input should appear before Outputs."); \ + const std::pair range = ctx->InputRangeAt(in_idx); \ + const tensor_type& arg = ctx->InputAt(range.first); \ + KernelCallHelper:: \ + template Compute( \ + ctx, pargs..., arg); \ + } \ + } + +#define PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(tensor_type) \ + template \ + struct KernelCallHelper, Tail...> { \ + template \ + static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ + static_assert(attr_idx == 0, \ + "Kernel's Input should appear before Attributes."); \ + static_assert(out_idx == 0, \ + "Kernel's Input should appear before Outputs."); \ + const std::pair range = ctx->InputRangeAt(in_idx); \ + auto arg = ctx->OptionalInputAt(range.first); \ + KernelCallHelper:: \ + template Compute( \ + ctx, pargs..., arg); \ + } \ + } + +#define PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(tensor_type) \ + template \ + struct KernelCallHelper&, Tail...> { \ + template \ + static void Compute(KernelContext* ctx, PreviousArgs&... 
pargs) { \ + static_assert(attr_idx == 0, \ + "Kernel's Input should appear before Attributes."); \ + static_assert(out_idx == 0, \ + "Kernel's Input should appear before Outputs."); \ + const std::pair range = ctx->InputRangeAt(in_idx); \ + std::vector arg = std::move( \ + ctx->MoveInputsBetween(range.first, range.second)); \ + KernelCallHelper:: \ + template Compute( \ + ctx, pargs..., arg); \ + } \ + } + +#define PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(attr_type) \ + template \ + struct KernelCallHelper { \ + template \ + static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ + static_assert(out_idx == 0, \ + "Kernel's Attributes should appear before Outputs."); \ + attr_type arg = ctx->AttrAt(attr_idx); \ + KernelCallHelper:: \ + template Compute( \ + ctx, pargs..., arg); \ + } \ + } + +#define PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(tensor_type) \ + template \ + struct KernelCallHelper { \ + template \ + static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ + const std::pair range = ctx->OutputRangeAt(out_idx); \ + tensor_type* arg = ctx->MutableOutputAt(range.first); \ + KernelCallHelper:: \ + template Compute( \ + ctx, pargs..., arg); \ + } \ + } + +#define PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(tensor_type) \ + template \ + struct KernelCallHelper, Tail...> { \ + template \ + static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ + const std::pair range = ctx->OutputRangeAt(out_idx); \ + std::vector arg = std::move( \ + ctx->MutableOutputBetween(range.first, range.second)); \ + KernelCallHelper:: \ + template Compute( \ + ctx, pargs..., arg); \ + } \ + } + +template +struct TypeTag {}; + +template +struct KernelImpl; + +template +struct KernelImpl { + static void Compute(KernelContext* ctx) { + KernelCallHelper>::template Compute<0, 0, 0, 0>(ctx); + } + + static void VariadicCompute(const DeviceContext& dev_ctx, Args... 
args) { + return kernel_fn(static_cast(dev_ctx), std::forward(args)...); + } + + private: + template + struct KernelCallHelper; + + /* DeviceContext Helpers */ + + PT_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(CPUContext); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + PT_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(GPUContext); +#endif +#ifdef PADDLE_WITH_XPU + PT_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(XPUContext); +#endif + + /* Input Helpers */ + + PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor); + PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor); + PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor); + // TODO(chenweihang): adapt SelectedRows + // PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRowsTensor); + + /* Attribute Helpers */ + + PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(bool); + PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(float); + PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(double); + PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int); + PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int64_t); + PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(paddle::platform::float16); + PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&); + PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataType); + PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); + PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const ScalarArray&); + PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); + + /* Output Helpers */ + + PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(DenseTensor); + PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(DenseTensor); + // TODO(chenweihang): adapt SelectedRows + // PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SelectedRowsTensor); + + /* End case */ + template + struct KernelCallHelper> { + template + static void Compute(KernelContext* ctx, DevCtx dev_ctx, Args&... args) { + static_assert(dev_ctx_idx > 0, + "Kernel should pass DeviceContext as argument."); + static_assert(out_idx > 0, "Kernel should have output argument."); + // TODO(chenweihang): check dev_ctx, in, attr, out number + return kernel_fn(dev_ctx, args...); + } + }; +}; + +} // namespace pten diff --git a/paddle/pten/include/linalg.h b/paddle/pten/include/linalg.h index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..71bc518aa89f8f4a2aeeda67448ed171eaf94265 100644 --- a/paddle/pten/include/linalg.h +++ b/paddle/pten/include/linalg.h @@ -0,0 +1,37 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
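The linalg.h header continued below wraps the dot kernel in a thin convenience function: Dot infers the output meta via DotInferMeta, allocates the result tensor, and forwards to DotKernel. A minimal sketch of a call site, assuming the elided template parameters of Dot are the value type T and a deduced context type ContextT; the helper name below is hypothetical and not part of this patch:

    // Illustrative only: dev_ctx, x and y are assumed to be a valid
    // pten::CPUContext and two float DenseTensors of compatible shape.
    pten::DenseTensor RunDotExample(const pten::CPUContext& dev_ctx,
                                    const pten::DenseTensor& x,
                                    const pten::DenseTensor& y) {
      // Infers the output meta, allocates storage, then dispatches DotKernel<float>.
      return pten::Dot<float>(dev_ctx, x, y);
    }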
+ +#pragma once + +// See Note: [ How do we organize the kernel directory ] +#include "paddle/pten/api/lib/utils/storage.h" +#include "paddle/pten/include/infermeta.h" +#include "paddle/pten/kernels/dot_kernel.h" + +namespace pten { + +template +DenseTensor Dot(const ContextT& dev_ctx, + const DenseTensor& x, + const DenseTensor& y) { + auto out_meta = DotInferMeta(x.meta(), y.meta()); + pten::DenseTensor dense_out( + pten::make_intrusive( + dev_ctx.GetPlace()), + std::move(out_meta)); + DotKernel(dev_ctx, x, y, &dense_out); + return dense_out; +} + +} // namespace pten diff --git a/paddle/pten/include/math.h b/paddle/pten/include/math.h index faa4c8db8dac319b653c9d9ebc067625004fbbe7..5070d0d4e0e5a29468b064641fbb9421ac13dec6 100644 --- a/paddle/pten/include/math.h +++ b/paddle/pten/include/math.h @@ -48,15 +48,4 @@ DenseTensor Scale(const ContextT& dev_ctx, return dense_out; } -template -DenseTensor Conj(const ContextT& dev_ctx, const DenseTensor& x) { - auto out_meta = UnchangedInferMeta(x.meta()); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - std::move(out_meta)); - Conj(dev_ctx, x, &dense_out); - return dense_out; -} - } // namespace pten diff --git a/paddle/pten/kernels/complex_kernel.h b/paddle/pten/kernels/complex_kernel.h index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..e9f717152a458c90ba1c53f77f53cb3d0f0ce611 100644 --- a/paddle/pten/kernels/complex_kernel.h +++ b/paddle/pten/kernels/complex_kernel.h @@ -0,0 +1,35 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" + +#include "paddle/pten/infermeta/unary.h" +#include "paddle/pten/kernels/empty_kernel.h" + +namespace pten { + +template +void ConjKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); + +template +DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) { + auto out_meta = UnchangedInferMeta(x.meta()); + auto dense_out = Empty(dev_ctx, std::move(out_meta)); + ConjKernel(dev_ctx, x, &dense_out); + return dense_out; +} + +} // namespace pten diff --git a/paddle/pten/kernels/cpu/complex_kernel.cc b/paddle/pten/kernels/cpu/complex_kernel.cc index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..10e7e684db3c1a57badc92ecd0eb9b93f6bb0d6a 100644 --- a/paddle/pten/kernels/cpu/complex_kernel.cc +++ b/paddle/pten/kernels/cpu/complex_kernel.cc @@ -0,0 +1,33 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/complex_kernel.h" +#include "paddle/pten/kernels/impl/complex_kernel_impl.h" + +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/kernel_registry.h" + +// See Note [ Why still include the fluid headers? ] +#include "paddle/fluid/platform/complex.h" + +PT_REGISTER_CTX_KERNEL(conj, + CPU, + ALL_LAYOUT, + pten::ConjKernel, + paddle::platform::complex, + paddle::platform::complex, + float, + double, + int, + int64_t) {} diff --git a/paddle/pten/kernels/cpu/dot_grad_kernel.cc b/paddle/pten/kernels/cpu/dot_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..c9d5c35e134c8310be69f7c8af9fcb6691624d55 --- /dev/null +++ b/paddle/pten/kernels/cpu/dot_grad_kernel.cc @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/dot_grad_kernel.h" +#include "paddle/pten/kernels/impl/dot_grad_kernel_impl.h" + +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/kernel_registry.h" + +#include "paddle/fluid/platform/complex.h" + +PT_REGISTER_CTX_KERNEL(dot_grad, + CPU, + ALL_LAYOUT, + pten::DotGradKernel, + float, + double, + int, + int64_t, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/cpu/dot_kernel.cc b/paddle/pten/kernels/cpu/dot_kernel.cc index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..72e9e28907f909261288a264846d9b854185aeb2 100644 --- a/paddle/pten/kernels/cpu/dot_kernel.cc +++ b/paddle/pten/kernels/cpu/dot_kernel.cc @@ -0,0 +1,61 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/dot_kernel.h" + +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/kernel_registry.h" + +// See Note [ Why still include the fluid headers? 
] +#include "paddle/fluid/platform/complex.h" + +namespace pten { + +template +void DotKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + auto const *x_ptr = x.data(), *x_ptr_ = &x_ptr[0]; + auto const *y_ptr = y.data(), *y_ptr_ = &y_ptr[0]; + auto* z = out->mutable_data(); + + // Loop over the total N elements of both operands while sum-reducing every + // B pairs along the way where B is the dimension of the least ordered axis + auto&& d = x.dims(); + auto const N = x.numel(); + auto const B = d[d.size() - 1]; + + for (int j = 0; j < N / B; j++) { + T ss = 0; + for (int i = 0; i < B; i++) ss += (*x_ptr_++) * (*y_ptr_++); + z[j] = ss; + } +} + +} // namespace pten + +using complex64 = ::paddle::platform::complex; +using complex128 = ::paddle::platform::complex; + +PT_REGISTER_CTX_KERNEL(dot, + CPU, + ALL_LAYOUT, + pten::DotKernel, + float, + double, + int, + int64_t, + complex64, + complex128) {} diff --git a/paddle/pten/kernels/cpu/matmul_grad_kernel.cc b/paddle/pten/kernels/cpu/matmul_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..5a8abb6701b0e0225d8354b8a8381f8a59ec1b23 --- /dev/null +++ b/paddle/pten/kernels/cpu/matmul_grad_kernel.cc @@ -0,0 +1,47 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/kernels/matmul_grad_kernel.h" + +#include "paddle/fluid/platform/complex.h" +#include "paddle/pten/core/kernel_registry.h" + +#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h" + +PT_REGISTER_CTX_KERNEL(matmul_grad, + CPU, + ALL_LAYOUT, + pten::MatmulGradKernel, + float, + double, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_CTX_KERNEL(matmul_double_grad, + CPU, + ALL_LAYOUT, + pten::MatmulDoubleGradKernel, + float, + double, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_CTX_KERNEL(matmul_triple_grad, + CPU, + ALL_LAYOUT, + pten::MatmulTripleGradKernel, + float, + double, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/dot_grad_kernel.h b/paddle/pten/kernels/dot_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..b0940e5b16a3354f930cc9c1c98020c82fe998c1 --- /dev/null +++ b/paddle/pten/kernels/dot_grad_kernel.h @@ -0,0 +1,56 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
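To make the indexing in the CPU DotKernel defined above concrete: the loop views each operand as N / B rows of length B, where N is the total element count and B is the size of the innermost dimension, and it emits one reduced sum per row. A self-contained sketch of the same reduction on plain arrays, with illustrative values that are not taken from the patch:

    #include <cstdio>

    int main() {
      const float x[6] = {1, 2, 3, 4, 5, 6};  // dims [2, 3], flattened row-major
      const float y[6] = {1, 1, 1, 2, 2, 2};
      const int N = 6, B = 3;                 // N = numel, B = last-dim size
      float out[2] = {0.0f, 0.0f};
      for (int j = 0; j < N / B; ++j) {       // one dot product per row
        for (int i = 0; i < B; ++i) {
          out[j] += x[j * B + i] * y[j * B + i];
        }
      }
      std::printf("%g %g\n", out[0], out[1]);  // prints: 6 30
      return 0;
    }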
+ +#pragma once + +#include "paddle/pten/core/dense_tensor.h" + +namespace pten { + +template +void DotGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy); + +template +void DotDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& ddx, + const DenseTensor& ddy, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy, + DenseTensor* ddout); + +template +void DotTripleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& ddx, + const DenseTensor& ddy, + const DenseTensor& d_dx, + const DenseTensor& d_dy, + const DenseTensor& dout, + const DenseTensor& d_ddout, + DenseTensor* d_x, + DenseTensor* d_y, + DenseTensor* d_ddx, + DenseTensor* d_ddy, + DenseTensor* d_dout); + +} // namespace pten diff --git a/paddle/pten/kernels/dot_kernel.h b/paddle/pten/kernels/dot_kernel.h index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..5ef660265333e9b579f3ed9faa8a7d99defbfd69 100644 --- a/paddle/pten/kernels/dot_kernel.h +++ b/paddle/pten/kernels/dot_kernel.h @@ -0,0 +1,27 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" + +namespace pten { + +template +void DotKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out); + +} // namespace pten diff --git a/paddle/pten/kernels/empty_kernel.cc b/paddle/pten/kernels/empty_kernel.cc index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..2dd55a13e38e5472b3da8ca118a19946b23205c7 100644 --- a/paddle/pten/kernels/empty_kernel.cc +++ b/paddle/pten/kernels/empty_kernel.cc @@ -0,0 +1,99 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ +#include "paddle/pten/kernels/empty_kernel.h" + +#include "paddle/pten/backends/all_context.h" +#include "paddle/pten/core/kernel_registry.h" + +#include "paddle/fluid/platform/complex.h" + +namespace pten { + +template +void EmptyKernel(const Context& dev_ctx, + const ScalarArray& shape, + DenseTensor* out) { + out->Resize(paddle::framework::make_ddim(shape.GetData())); +} + +template +void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) { + out->mutable_data(); +} + +} // namespace pten + +PT_REGISTER_CTX_KERNEL(empty, + CPU, + ALL_LAYOUT, + pten::EmptyKernel, + float, + double, + uint8_t, + int16_t, + int, + int64_t, + bool, + paddle::platform::float16, + paddle::platform::bfloat16, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_CTX_KERNEL(empty_like, + CPU, + ALL_LAYOUT, + pten::EmptyLikeKernel, + float, + double, + uint8_t, + int16_t, + int, + int64_t, + bool, + paddle::platform::float16, + paddle::platform::bfloat16, + paddle::platform::complex, + paddle::platform::complex) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PT_REGISTER_CTX_KERNEL(empty, + GPU, + ALL_LAYOUT, + pten::EmptyKernel, + float, + double, + uint8_t, + int16_t, + int, + int64_t, + bool, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_CTX_KERNEL(empty_like, + GPU, + ALL_LAYOUT, + pten::EmptyLikeKernel, + float, + double, + uint8_t, + int16_t, + int, + int64_t, + bool, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} +#endif diff --git a/paddle/pten/kernels/empty_kernel.h b/paddle/pten/kernels/empty_kernel.h index d71ee0b1266f2d5ab3989fe57f4eb5dff7d5cf39..d283ef5c1e41ef64a2f4a38b382595fa2db3bd90 100644 --- a/paddle/pten/kernels/empty_kernel.h +++ b/paddle/pten/kernels/empty_kernel.h @@ -41,6 +41,14 @@ DenseTensor Empty(const Context& dev_ctx, DenseTensorMeta&& meta) { return dense_out; } +template +DenseTensor Empty(const Context& dev_ctx) { + return Empty(dev_ctx, + {paddle::experimental::CppTypeToDataType::Type(), + {-1}, + DataLayout::NCHW}); +} + template DenseTensor Empty(const Context& dev_ctx, const ScalarArray& shape, diff --git a/paddle/pten/kernels/gpu/complex_kernel.cu b/paddle/pten/kernels/gpu/complex_kernel.cu index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..02f050f5bc838b73128a734c84104df24a39fa42 100644 --- a/paddle/pten/kernels/gpu/complex_kernel.cu +++ b/paddle/pten/kernels/gpu/complex_kernel.cu @@ -0,0 +1,34 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/complex_kernel.h" +#include "paddle/pten/kernels/impl/complex_kernel_impl.h" + +#include "paddle/pten/backends/gpu/gpu_context.h" +#include "paddle/pten/core/kernel_registry.h" + +// See Note [ Why still include the fluid headers? 
] +#include "paddle/fluid/platform/complex.h" + +PT_REGISTER_CTX_KERNEL(conj, + GPU, + ALL_LAYOUT, + pten::ConjKernel, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex, + float, + double, + int, + int64_t) {} diff --git a/paddle/pten/kernels/gpu/dot_grad_kernel.cu b/paddle/pten/kernels/gpu/dot_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..42af96f7c7265df306ffdb7493c96522376e842e --- /dev/null +++ b/paddle/pten/kernels/gpu/dot_grad_kernel.cu @@ -0,0 +1,32 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/kernels/dot_grad_kernel.h" +#include "paddle/pten/kernels/impl/dot_grad_kernel_impl.h" + +#include "paddle/pten/backends/gpu/gpu_context.h" +#include "paddle/pten/core/kernel_registry.h" + +#include "paddle/fluid/platform/complex.h" + +PT_REGISTER_CTX_KERNEL(dot_grad, + GPU, + ALL_LAYOUT, + pten::DotGradKernel, + float, + double, + int, + int64_t, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/gpu/dot_kernel.cu b/paddle/pten/kernels/gpu/dot_kernel.cu index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..1f9e7aa3f1cfd54e797cc63063a225389fba7bd2 100644 --- a/paddle/pten/kernels/gpu/dot_kernel.cu +++ b/paddle/pten/kernels/gpu/dot_kernel.cu @@ -0,0 +1,64 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/dot_kernel.h" + +#include "paddle/pten/backends/gpu/gpu_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/hybird/eigen/common.h" + +// See Note [ Why still include the fluid headers? 
] +#include "paddle/fluid/operators/eigen/eigen_function.h" +#include "paddle/fluid/platform/complex.h" + +namespace pten { + +template +void DotKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + out->mutable_data(); + if (1 == out->dims().size()) { + auto eigen_out = pten::EigenScalar::From(*out); + auto eigen_x = pten::EigenVector::Flatten(x); + auto eigen_y = pten::EigenVector::Flatten(y); + + auto& dev = *dev_ctx.eigen_device(); + eigen_out.device(dev) = (eigen_x * eigen_y).sum(); + } else { + auto eigen_out = pten::EigenMatrix::From(*out); + auto eigen_x = pten::EigenMatrix::From(x); + auto eigen_y = pten::EigenMatrix::From(y); + + auto& dev = *dev_ctx.eigen_device(); + eigen_out.device(dev) = (eigen_x * eigen_y).sum(Eigen::DSizes(1)); + } +} + +} // namespace pten + +using complex64 = ::paddle::platform::complex; +using complex128 = ::paddle::platform::complex; + +PT_REGISTER_CTX_KERNEL(dot, + GPU, + ALL_LAYOUT, + pten::DotKernel, + float, + double, + int, + int64_t, + complex64, + complex128) {} diff --git a/paddle/pten/kernels/gpu/matmul_grad_kernel.cu b/paddle/pten/kernels/gpu/matmul_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..f20c3f82c9262356fd483b80de33d9f6f332a597 --- /dev/null +++ b/paddle/pten/kernels/gpu/matmul_grad_kernel.cu @@ -0,0 +1,50 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/kernels/matmul_grad_kernel.h" + +#include "paddle/fluid/platform/complex.h" +#include "paddle/pten/core/kernel_registry.h" + +#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h" + +PT_REGISTER_CTX_KERNEL(matmul_grad, + GPU, + ALL_LAYOUT, + pten::MatmulGradKernel, + float, + double, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_CTX_KERNEL(matmul_double_grad, + GPU, + ALL_LAYOUT, + pten::MatmulDoubleGradKernel, + float, + double, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_CTX_KERNEL(matmul_triple_grad, + GPU, + ALL_LAYOUT, + pten::MatmulTripleGradKernel, + float, + double, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/hybird/transpose.h b/paddle/pten/kernels/hybird/transpose.h index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..17f52c74a13441ce4b856629fd40de4aa613323e 100644 --- a/paddle/pten/kernels/hybird/transpose.h +++ b/paddle/pten/kernels/hybird/transpose.h @@ -0,0 +1,62 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ddim.h" +#include "paddle/pten/core/dense_tensor.h" + +#include "paddle/fluid/operators/eigen/eigen_function.h" +#include "paddle/pten/kernels/hybird/eigen/common.h" + +namespace pten { + +namespace math { + +template +struct TransposeNormal { + // for dims >= 7 situation + void operator()(const DeviceContext& dev_ctx, + const pten::DenseTensor& in, + pten::DenseTensor* out, + const std::vector& axis); +}; + +template +struct Transpose { + void operator()(const DeviceContext& dev_ctx, + const DenseTensor& in, + DenseTensor* out, + const std::vector& axis) { + Eigen::array permute; + for (int i = 0; i < Rank; i++) { + permute[i] = axis[i]; + } + auto eigen_in = pten::EigenTensor::From(in); + auto eigen_out = pten::EigenTensor::From(*out); + auto* dev = dev_ctx.eigen_device(); + // use 32bit index to speed up computation + bool use_32bit_index = eigen_out.size() < Eigen::NumTraits::highest(); + bool is_gpu_place = paddle::platform::is_gpu_place(dev_ctx.GetPlace()); + if (use_32bit_index && is_gpu_place) { + To32BitIndex(eigen_out).device(*dev) = + To32BitIndex(eigen_in).shuffle(permute); + } else { + eigen_out.device(*dev) = eigen_in.shuffle(permute); + } + } +}; + +} // namespace math +} // namespace pten diff --git a/paddle/pten/kernels/impl/complex_kernel_impl.h b/paddle/pten/kernels/impl/complex_kernel_impl.h index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..e0c6825a78a53c05a69a368156caeebffbe8dea6 100644 --- a/paddle/pten/kernels/impl/complex_kernel_impl.h +++ b/paddle/pten/kernels/impl/complex_kernel_impl.h @@ -0,0 +1,36 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +// See Note [ Why still include the fluid headers? 
] +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/platform/for_range.h" + +namespace pten { + +template +void ConjKernel(const Context& context, + const DenseTensor& x, + DenseTensor* out) { + auto numel = x.numel(); + auto* x_data = x.data(); + auto* out_data = out->mutable_data(); + + paddle::platform::ForRange for_range(context, numel); + paddle::operators::math::ConjFunctor functor(x_data, numel, out_data); + for_range(functor); +} + +} // namespace pten diff --git a/paddle/pten/kernels/impl/dot_grad_kernel_impl.h b/paddle/pten/kernels/impl/dot_grad_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..16c87bbab474abe5c0d26394e6b5476d8e9dfa14 --- /dev/null +++ b/paddle/pten/kernels/impl/dot_grad_kernel_impl.h @@ -0,0 +1,919 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/kernels/hybird/eigen/common.h" + +#include "paddle/pten/kernels/complex_kernel.h" + +#include "paddle/fluid/operators/eigen/eigen_function.h" +#include "paddle/fluid/operators/math/complex_functors.h" + +namespace pten { + +template +struct DotGradFunction { + void operator()(const DeviceContext& ctx, + const DenseTensor* tensor_x, + const DenseTensor* tensor_y, + const DenseTensor* tensor_dout, + DenseTensor* tensor_dx, + DenseTensor* tensor_dy); +}; + +template +struct DotGradFunction> { + void operator()(const DeviceContext& ctx, + const DenseTensor* tensor_x, + const DenseTensor* tensor_y, + const DenseTensor* tensor_dout, + DenseTensor* tensor_dx, + DenseTensor* tensor_dy) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == tensor_dout->dims().size()) { + auto dout = EigenVector::Flatten(*tensor_dout); + + if (tensor_dx) { + auto y = EigenVector::Flatten(*tensor_y); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(tensor_dx->numel()); + + ConjKernel(ctx, *tensor_y, tensor_dx); + + auto dx = EigenVector::Flatten(*tensor_dx); + dx.device(dev) = dx * dout.broadcast(size); + } + + if (tensor_dy) { + auto x = EigenVector::Flatten(*tensor_x); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(tensor_dy->numel()); + + ConjKernel(ctx, *tensor_x, tensor_dy); + + auto dy = EigenVector::Flatten(*tensor_dy); + dy.device(dev) = dy * dout.broadcast(size); + } + } else { + auto dout = EigenMatrix::From(*tensor_dout); + + if (tensor_dx) { + tensor_dx->mutable_data(); + auto y = EigenMatrix::From(*tensor_y); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(1, tensor_dx->dims()[1]); + + ConjKernel(ctx, *tensor_y, tensor_dx); + + auto dx = EigenMatrix::From(*tensor_dx); + dx.device(dev) = dx * dout.broadcast(size); + } + + if (tensor_dy) { + tensor_dy->mutable_data(); + auto x = EigenMatrix::From(*tensor_x); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(1, tensor_dy->dims()[1]); + + ConjKernel(ctx, *tensor_x, tensor_dy); + + auto dy = EigenMatrix::From(*tensor_dy); + dy.device(dev) = dy * 
dout.broadcast(size); + } + } +#else + const auto* data_dout = tensor_dout->data(); + + if (tensor_dx) { + auto* data_dx = tensor_dx->mutable_data(); + const auto* data_y = tensor_y->data(); + const DDim& dim = tensor_x->dims(); + size_t N = static_cast(paddle::framework::product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dx[i] = T(data_y[i].real, -data_y[i].imag) * data_dout[s]; + } + } + + if (tensor_dy) { + auto* data_dy = tensor_dy->mutable_data(); + const auto* data_x = tensor_x->data(); + const DDim& dim = tensor_y->dims(); + size_t N = static_cast(paddle::framework::product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dy[i] = T(data_x[i].real, -data_x[i].imag) * data_dout[s]; + } + } +#endif + } +}; + +template +struct DotGradFunction> { + void operator()(const DeviceContext& ctx, + const DenseTensor* tensor_x, + const DenseTensor* tensor_y, + const DenseTensor* tensor_dout, + DenseTensor* tensor_dx, + DenseTensor* tensor_dy) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == tensor_dout->dims().size()) { + auto dout = EigenVector::Flatten(*tensor_dout); + if (tensor_dx) { + auto y = EigenVector::Flatten(*tensor_y); + auto dx = EigenVector::Flatten(*tensor_dx); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(tensor_dx->numel()); + dx.device(dev) = y * dout.broadcast(size); + } + + if (tensor_dy) { + auto x = EigenVector::Flatten(*tensor_x); + auto dy = EigenVector::Flatten(*tensor_dy); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(tensor_dy->numel()); + dy.device(dev) = x * dout.broadcast(size); + } + } else { + auto dout = EigenMatrix::From(*tensor_dout); + + if (tensor_dx) { + tensor_dx->mutable_data(); + auto y = EigenMatrix::From(*tensor_y); + auto dx = EigenMatrix::From(*tensor_dx); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(1, tensor_dx->dims()[1]); + dx.device(dev) = y * dout.broadcast(size); + } + + if (tensor_dy) { + tensor_dy->mutable_data(); + auto x = EigenMatrix::From(*tensor_x); + auto dy = EigenMatrix::From(*tensor_dy); + auto& dev = *ctx.eigen_device(); + Eigen::DSizes size(1, tensor_dy->dims()[1]); + dy.device(dev) = x * dout.broadcast(size); + } + } +#else + auto const *x = tensor_x->data(), *y = tensor_y->data(), + *dz = tensor_dout->data(); + auto&& d = tensor_x->dims(); + auto const N = tensor_x->numel(); + auto const B = d[d.size() - 1]; + + if (tensor_dx) { + auto* dx = tensor_dx->mutable_data(); + for (auto j = 0; j < N / B; ++j) { + auto const ss = dz[j]; + for (auto i = 0; i < B; ++i) *dx++ = *y++ * ss; + } + } + + if (tensor_dy) { + auto* dy = tensor_dy->mutable_data(); + for (auto j = 0; j < N / B; ++j) { + auto const ss = dz[j]; + for (auto i = 0; i < B; i++) *dy++ = *x++ * ss; + } + } +#endif + } +}; + +template +struct DotDoubleGradFunction { + void operator()(const DeviceContext& ctx, + const DenseTensor* tensor_x, + const DenseTensor* tensor_y, + const DenseTensor* tensor_dout, + const DenseTensor* tensor_ddx, + const DenseTensor* tensor_ddy, + DenseTensor* tensor_dx, + DenseTensor* tensor_dy, + DenseTensor* tensor_ddout); +}; + +template +struct DotDoubleGradFunction> { + void operator()(const DeviceContext& ctx, + const DenseTensor* tensor_x, + const DenseTensor* tensor_y, + const DenseTensor* tensor_dout, + const DenseTensor* tensor_ddx, + const DenseTensor* tensor_ddy, + DenseTensor* tensor_dx, + DenseTensor* tensor_dy, + DenseTensor* 
tensor_ddout) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == tensor_dout->dims().size()) { + DenseTensor tensor_dout_help; + auto& dev = *ctx.eigen_device(); + if (tensor_dx || tensor_dy) { + tensor_dout_help = Conj(ctx, *tensor_dout); + } + if (tensor_dx) { + auto ddy = EigenVector::Flatten(*tensor_ddy); + Eigen::DSizes size(tensor_ddy->numel()); + auto dx = EigenVector::Flatten(*tensor_dx); + auto dout = EigenVector::Flatten(tensor_dout_help); + dx.device(dev) = ddy * dout.broadcast(size); + } + + if (tensor_dy) { + auto ddx = EigenVector::Flatten(*tensor_ddx); + Eigen::DSizes size(tensor_ddx->numel()); + auto dy = EigenVector::Flatten(*tensor_dy); + auto dout = EigenVector::Flatten(tensor_dout_help); + dy.device(dev) = ddx * dout.broadcast(size); + } + + if (tensor_ddout) { + DenseTensor tensor_x_help = Conj(ctx, *tensor_x); + DenseTensor tensor_y_help = Conj(ctx, *tensor_y); + + auto x = EigenVector::Flatten(tensor_x_help); + auto y = EigenVector::Flatten(tensor_y_help); + auto ddx = EigenVector::Flatten(*tensor_ddx); + auto ddy = EigenVector::Flatten(*tensor_ddy); + auto ddout = EigenVector::Flatten(*tensor_ddout); + ddout.device(dev) = (x * ddy + y * ddx).sum(); + } + } +#else + const auto* data_dout = tensor_dout->data(); + + if (tensor_dx) { + auto* data_dx = tensor_dx->mutable_data(); + const auto* data_ddy = tensor_ddy->data(); + const DDim& dim = tensor_dx->dims(); + size_t N = static_cast(product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dx[i] = T(data_dout[s].real, -data_dout[s].imag) * data_ddy[i]; + } + } + + if (tensor_dy) { + auto* data_dy = tensor_dy->mutable_data(); + const auto* data_ddx = tensor_ddx->data(); + const DDim& dim = tensor_dy->dims(); + size_t N = static_cast(product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dy[i] = T(data_dout[s].real, -data_dout[s].imag) * data_ddx[i]; + } + } + + if (tensor_ddout) { + auto* data_ddout = tensor_ddout->mutable_data(); + auto* data_x = tensor_x->data(); + auto* data_y = tensor_y->data(); + auto* data_ddx = tensor_ddx->data(); + auto* data_ddy = tensor_ddy->data(); + + const DDim& dim = tensor_dy->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + bool new_s = false; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) { + ++s; + new_s = true; + } + if (new_s) { + data_ddout[s] = T(data_x[i].real, -data_x[i].imag) * data_ddy[i] + + T(data_y[i].real, -data_y[i].imag) * data_ddx[i]; + } else { + data_ddout[s] += T(data_x[i].real, -data_x[i].imag) * data_ddy[i] + + T(data_y[i].real, -data_y[i].imag) * data_ddx[i]; + } + new_s = false; + } + } +#endif + } +}; + +template +struct DotDoubleGradFunction> { + void operator()(const DeviceContext& ctx, + const DenseTensor* tensor_x, + const DenseTensor* tensor_y, + const DenseTensor* tensor_dout, + const DenseTensor* tensor_ddx, + const DenseTensor* tensor_ddy, + DenseTensor* tensor_dx, + DenseTensor* tensor_dy, + DenseTensor* tensor_ddout) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == tensor_dout->dims().size()) { + auto& dev = *ctx.eigen_device(); + auto dout = EigenVector::Flatten(*tensor_dout); + if (tensor_dx) { + tensor_dx->mutable_data(); + auto ddy = EigenVector::Flatten(*tensor_ddy); + Eigen::DSizes size(tensor_ddy->numel()); + auto dx = EigenVector::Flatten(*tensor_dx); + dx.device(dev) = ddy * dout.broadcast(size); + } + + if 
(tensor_dy) { + tensor_dy->mutable_data(); + auto ddx = EigenVector::Flatten(*tensor_ddx); + Eigen::DSizes size(tensor_ddx->numel()); + + auto dy = EigenVector::Flatten(*tensor_dy); + dy.device(dev) = ddx * dout.broadcast(size); + } + + if (tensor_ddout) { + tensor_ddout->mutable_data(); + auto x = EigenVector::Flatten(*tensor_x); + auto y = EigenVector::Flatten(*tensor_y); + auto ddx = EigenVector::Flatten(*tensor_ddx); + auto ddy = EigenVector::Flatten(*tensor_ddy); + auto ddout = EigenVector::Flatten(*tensor_ddout); + ddout.device(dev) = (x * ddy + y * ddx).sum(); + } + } +#else + const auto* data_dout = tensor_dout->data(); + + if (tensor_dx) { + auto* data_dx = tensor_dx->mutable_data(); + const auto* data_ddy = tensor_ddy->data(); + const DDim& dim = tensor_dx->dims(); + size_t N = static_cast(product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dx[i] = data_dout[s] * data_ddy[i]; + } + } + + if (tensor_dy) { + auto* data_dy = tensor_dy->mutable_data(); + const auto* data_ddx = tensor_ddx->data(); + const DDim& dim = tensor_dy->dims(); + size_t N = static_cast(product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dy[i] = data_dout[s] * data_ddx[i]; + } + } + + if (tensor_ddout) { + auto* data_ddout = tensor_ddout->mutable_data(); + auto* data_x = tensor_x->data(); + auto* data_y = tensor_y->data(); + auto* data_ddx = tensor_ddx->data(); + auto* data_ddy = tensor_ddy->data(); + + const DDim& dim = tensor_dy->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + bool new_s = false; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) { + ++s; + new_s = true; + } + if (new_s) { + data_ddout[s] = data_x[i] * data_ddy[i] + data_y[i] * data_ddx[i]; + } else { + data_ddout[s] += data_x[i] * data_ddy[i] + data_y[i] * data_ddx[i]; + } + new_s = false; + } + } +#endif + } +}; + +template +struct DotTripleGradFunction { + void operator()(const DeviceContext& ctx, + const DenseTensor* in_tensor_x, + const DenseTensor* in_tensor_y, + const DenseTensor* in_tensor_ddx, + const DenseTensor* in_tensor_ddy, + const DenseTensor* in_tensor_d_dx, + const DenseTensor* in_tensor_d_dy, + const DenseTensor* in_tensor_dout, + const DenseTensor* in_tensor_d_ddout, + DenseTensor* out_tensor_d_x, + DenseTensor* out_tensor_d_y, + DenseTensor* out_tensor_d_dout, + DenseTensor* out_tensor_d_ddx, + DenseTensor* out_tensor_d_ddy); +}; + +// TODO(wuweilong): enable this function when the unittests framewark for multi +// grad is ok (dtype: complex64 or complex128). 
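+// For reference, the specialization below implements the conjugated dot
+// triple-grad formulas (d_ddout and dout are broadcast to the shape of the
+// element-wise operands):
+//   d_x    = ddy * conj(d_ddout)
+//   d_y    = ddx * conj(d_ddout)
+//   d_dout = sum(conj(ddx) * d_dy + conj(ddy) * d_dx)
+//   d_ddx  = conj(dout) * d_dy + conj(y) * d_ddout
+//   d_ddy  = conj(dout) * d_dx + conj(x) * d_ddout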
+template +struct DotTripleGradFunction> { + void operator()(const DeviceContext& ctx, + const DenseTensor* in_tensor_x, + const DenseTensor* in_tensor_y, + const DenseTensor* in_tensor_ddx, + const DenseTensor* in_tensor_ddy, + const DenseTensor* in_tensor_d_dx, + const DenseTensor* in_tensor_d_dy, + const DenseTensor* in_tensor_dout, + const DenseTensor* in_tensor_d_ddout, + DenseTensor* out_tensor_d_x, + DenseTensor* out_tensor_d_y, + DenseTensor* out_tensor_d_dout, + DenseTensor* out_tensor_d_ddx, + DenseTensor* out_tensor_d_ddy) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == in_tensor_d_ddout->dims().size()) { + DenseTensor in_tensor_d_ddout_help; + auto& dev = *ctx.eigen_device(); + if (out_tensor_d_x || out_tensor_d_y) { + in_tensor_d_ddout_help = + Conj(ctx, *in_tensor_d_ddout); + } + if (out_tensor_d_x) { + auto ddy = EigenVector::Flatten(*in_tensor_ddy); + Eigen::DSizes size(in_tensor_ddy->numel()); + auto d_x = EigenVector::Flatten(*out_tensor_d_x); + auto d_ddout = EigenVector::Flatten(in_tensor_d_ddout_help); + d_x.device(dev) = ddy * d_ddout.broadcast(size); + } + + if (out_tensor_d_y) { + auto ddx = EigenVector::Flatten(*in_tensor_ddx); + Eigen::DSizes size(in_tensor_ddx->numel()); + auto d_y = EigenVector::Flatten(*out_tensor_d_y); + auto d_ddout = EigenVector::Flatten(in_tensor_d_ddout_help); + d_y.device(dev) = ddx * d_ddout.broadcast(size); + } + + if (out_tensor_d_dout) { + DenseTensor in_tensor_ddx_help = + Conj(ctx, *in_tensor_ddx); + DenseTensor in_tensor_ddy_help = + Conj(ctx, *in_tensor_ddy); + + auto ddx = EigenVector::Flatten(in_tensor_ddx_help); + auto ddy = EigenVector::Flatten(in_tensor_ddy_help); + auto d_dx = EigenVector::Flatten(*in_tensor_d_dx); + auto d_dy = EigenVector::Flatten(*in_tensor_d_dy); + auto d_dout = EigenVector::Flatten(*out_tensor_d_dout); + d_dout.device(dev) = (ddx * d_dy + ddy * d_dx).sum(); + } + + if (out_tensor_d_ddx) { + DenseTensor in_tensor_dout_help = + Conj(ctx, *in_tensor_dout); + DenseTensor in_tensor_y_help = + Conj(ctx, *in_tensor_y); + + auto dout = EigenVector::Flatten(in_tensor_dout_help); + auto y = EigenVector::Flatten(in_tensor_y_help); + auto d_ddout = EigenVector::Flatten(*in_tensor_d_ddout); + auto d_dy = EigenVector::Flatten(*in_tensor_d_dy); + auto d_ddx = EigenVector::Flatten(*out_tensor_d_ddx); + Eigen::DSizes size(in_tensor_y->numel()); + d_ddx.device(dev) = + (dout.broadcast(size) * d_dy + y * d_ddout.broadcast(size)); + } + + if (out_tensor_d_ddy) { + DenseTensor in_tensor_dout_help = + Conj(ctx, *in_tensor_dout); + DenseTensor in_tensor_x_help = + Conj(ctx, *in_tensor_x); + + auto dout = EigenVector::Flatten(in_tensor_dout_help); + auto x = EigenVector::Flatten(in_tensor_x_help); + auto d_ddout = EigenVector::Flatten(*in_tensor_d_ddout); + auto d_dx = EigenVector::Flatten(*in_tensor_d_dx); + auto d_ddy = EigenVector::Flatten(*out_tensor_d_ddy); + Eigen::DSizes size(in_tensor_x->numel()); + d_ddy.device(dev) = + (dout.broadcast(size) * d_dx + x * d_ddout.broadcast(size)); + } + } +#else + const auto* data_d_ddout = in_tensor_d_ddout->data(); + + if (out_tensor_d_x) { + auto* data_d_x = out_tensor_d_x->mutable_data(); + const auto* data_ddy = in_tensor_ddy->data(); + + const DDim& dim = out_tensor_d_x->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_x[i] = T(data_ddy[i].real, -data_ddy[i].imag) * data_d_ddout[s]; + } + } + + if (out_tensor_d_y) { + auto* data_d_y = 
out_tensor_d_y->mutable_data(); + const auto* data_ddx = in_tensor_ddx->data(); + + const DDim& dim = out_tensor_d_y->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_y[i] = T(data_ddx[i].real, -data_ddx[i].imag) * data_d_ddout[s]; + } + } + + if (out_tensor_d_dout) { + auto* data_d_dout = out_tensor_d_dout->mutable_data(); + auto* data_ddx = in_tensor_ddx->data(); + auto* data_ddy = in_tensor_ddy->data(); + auto* data_d_dx = in_tensor_d_dx->data(); + auto* data_d_dy = in_tensor_d_dy->data(); + + const DDim& dim = out_tensor_d_dout->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + bool new_s = false; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) { + ++s; + new_s = true; + } + if (new_s) { + data_d_dout[s] = + T(data_ddy[i].real, -data_ddy[i].imag) * data_d_dx[i] + + T(data_ddx[i].real, -data_ddx[i].imag) * data_d_dy[i]; + } else { + data_d_dout[s] += + T(data_ddy[i].real, -data_ddy[i].imag) * data_d_dx[i] + + T(data_ddx[i].real, -data_ddx[i].imag) * data_d_dy[i]; + } + new_s = false; + } + } + + if (out_tensor_d_ddx) { + auto* data_d_ddx = out_tensor_d_ddx->mutable_data(); + auto* data_dout = in_tensor_dout->data(); + auto* data_d_dy = in_tensor_d_dy->data(); + auto* data_y = in_tensor_y->data(); + auto* data_d_ddout = in_tensor_d_ddout->data(); + + const DDim& dim = out_tensor_d_ddx->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_ddx[i] = + T(data_dout[s].real, -data_dout[s].imag) * data_d_dy[i] + + T(data_y[i].real, -data_y[i].imag) * data_d_ddout[s]; + } + } + + if (out_tensor_d_ddy) { + auto* data_d_ddy = out_tensor_d_ddy->mutable_data(); + auto* data_dout = in_tensor_dout->data(); + auto* data_d_dx = in_tensor_d_dx->data(); + auto* data_x = in_tensor_x->data(); + auto* data_d_ddout = in_tensor_d_ddout->data(); + + const DDim& dim = out_tensor_d_ddy->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_ddy[i] = + T(data_dout[s].real, -data_dout[s].imag) * data_d_dx[i] + + T(data_x[i].real, -data_x[i].imag) * data_d_ddout[s]; + } + } +#endif + } +}; + +template +struct DotTripleGradFunction> { + void operator()(const DeviceContext& ctx, + const DenseTensor* in_tensor_x, + const DenseTensor* in_tensor_y, + const DenseTensor* in_tensor_ddx, + const DenseTensor* in_tensor_ddy, + const DenseTensor* in_tensor_d_dx, + const DenseTensor* in_tensor_d_dy, + const DenseTensor* in_tensor_dout, + const DenseTensor* in_tensor_d_ddout, + DenseTensor* out_tensor_d_x, + DenseTensor* out_tensor_d_y, + DenseTensor* out_tensor_d_dout, + DenseTensor* out_tensor_d_ddx, + DenseTensor* out_tensor_d_ddy) { +#if defined(__NVCC__) || defined(__HIPCC__) + if (1 == in_tensor_d_ddout->dims().size()) { + auto& dev = *ctx.eigen_device(); + auto d_ddout = EigenVector::Flatten(*in_tensor_d_ddout); + if (out_tensor_d_x) { + out_tensor_d_x->mutable_data(); + auto ddy = EigenVector::Flatten(*in_tensor_ddy); + Eigen::DSizes size(in_tensor_ddy->numel()); + auto d_x = EigenVector::Flatten(*out_tensor_d_x); + d_x.device(dev) = ddy * d_ddout.broadcast(size); + } + + if (out_tensor_d_y) { + out_tensor_d_y->mutable_data(); + auto ddx = EigenVector::Flatten(*in_tensor_ddx); + Eigen::DSizes 
size(in_tensor_ddx->numel()); + + auto d_y = EigenVector::Flatten(*out_tensor_d_y); + d_y.device(dev) = ddx * d_ddout.broadcast(size); + } + + if (out_tensor_d_dout) { + out_tensor_d_dout->mutable_data(); + auto ddx = EigenVector::Flatten(*in_tensor_ddx); + auto ddy = EigenVector::Flatten(*in_tensor_ddy); + auto d_dx = EigenVector::Flatten(*in_tensor_d_dx); + auto d_dy = EigenVector::Flatten(*in_tensor_d_dy); + auto d_dout = EigenVector::Flatten(*out_tensor_d_dout); + d_dout.device(dev) = (ddx * d_dy + ddy * d_dx).sum(); + } + + if (out_tensor_d_ddx) { + out_tensor_d_ddx->mutable_data(); + auto dout = EigenVector::Flatten(*in_tensor_dout); + auto y = EigenVector::Flatten(*in_tensor_y); + auto d_ddout = EigenVector::Flatten(*in_tensor_d_ddout); + auto d_dy = EigenVector::Flatten(*in_tensor_d_dy); + auto d_ddx = EigenVector::Flatten(*out_tensor_d_ddx); + Eigen::DSizes size(in_tensor_y->numel()); + d_ddx.device(dev) = + (dout.broadcast(size) * d_dy + y * d_ddout.broadcast(size)); + } + + if (out_tensor_d_ddy) { + out_tensor_d_ddy->mutable_data(); + auto dout = EigenVector::Flatten(*in_tensor_dout); + auto x = EigenVector::Flatten(*in_tensor_x); + auto d_ddout = EigenVector::Flatten(*in_tensor_d_ddout); + auto d_dx = EigenVector::Flatten(*in_tensor_d_dx); + auto d_ddy = EigenVector::Flatten(*out_tensor_d_ddy); + Eigen::DSizes size(in_tensor_x->numel()); + d_ddy.device(dev) = + (dout.broadcast(size) * d_dx + x * d_ddout.broadcast(size)); + } + } +#else + const auto* data_d_ddout = in_tensor_d_ddout->data(); + + if (out_tensor_d_x) { + auto* data_d_x = out_tensor_d_x->mutable_data(); + const auto* data_ddy = in_tensor_ddy->data(); + + const DDim& dim = out_tensor_d_x->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_x[i] = data_ddy[i] * data_d_ddout[s]; + } + } + + if (out_tensor_d_y) { + auto* data_d_y = out_tensor_d_y->mutable_data(); + const auto* data_ddx = in_tensor_ddx->data(); + + const DDim& dim = out_tensor_d_y->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_y[i] = data_ddx[i] * data_d_ddout[s]; + } + } + + if (out_tensor_d_dout) { + auto* data_d_dout = out_tensor_d_dout->mutable_data(); + auto* data_ddx = in_tensor_ddx->data(); + auto* data_ddy = in_tensor_ddy->data(); + auto* data_d_dx = in_tensor_d_dx->data(); + auto* data_d_dy = in_tensor_d_dy->data(); + + const DDim& dim = in_tensor_ddx->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + bool new_s = false; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) { + ++s; + new_s = true; + } + if (new_s) { + data_d_dout[s] = + data_ddy[i] * data_d_dx[i] + data_ddx[i] * data_d_dy[i]; + } else { + data_d_dout[s] += + data_ddy[i] * data_d_dx[i] + data_ddx[i] * data_d_dy[i]; + } + new_s = false; + } + } + + if (out_tensor_d_ddx) { + auto* data_d_ddx = out_tensor_d_ddx->mutable_data(); + auto* data_dout = in_tensor_dout->data(); + auto* data_d_dy = in_tensor_d_dy->data(); + auto* data_y = in_tensor_y->data(); + auto* data_d_ddout = in_tensor_d_ddout->data(); + + const DDim& dim = out_tensor_d_ddx->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_ddx[i] = + data_dout[s] * data_d_dy[i] + data_y[i] * data_d_ddout[s]; + } + } + + if 
(out_tensor_d_ddy) { + auto* data_d_ddy = out_tensor_d_ddy->mutable_data(); + auto* data_dout = in_tensor_dout->data(); + auto* data_d_dx = in_tensor_d_dx->data(); + auto* data_x = in_tensor_x->data(); + auto* data_d_ddout = in_tensor_d_ddout->data(); + + const DDim& dim = out_tensor_d_ddy->dims(); + size_t N = static_cast(product(dim)); + auto step = dim[dim.size() - 1]; + int s = -1; + + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_d_ddy[i] = + data_dout[s] * data_d_dx[i] + data_x[i] * data_d_ddout[s]; + } + } +#endif + } +}; + +template +void DotGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy) { + if (dx) { + dx->mutable_data(); + } + if (dy) { + dy->mutable_data(); + } + DotGradFunction()(dev_ctx, &x, &y, &dout, dx, dy); +} + +template +void DotDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& ddx, + const DenseTensor& ddy, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy, + DenseTensor* ddout) { + if (dx) { + dx->mutable_data(); + } + if (dy) { + dy->mutable_data(); + } + if (ddout) { + ddout->mutable_data(); + } + DotDoubleGradFunction()( + dev_ctx, &x, &y, &dout, ddx, ddy, dx, dy, ddout); +} + +template +void DotTripleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& ddx, + const DenseTensor& ddy, + const DenseTensor& d_dx, + const DenseTensor& d_dy, + const DenseTensor& dout, + const DenseTensor& d_ddout, + DenseTensor* d_x, + DenseTensor* d_y, + DenseTensor* d_ddx, + DenseTensor* d_ddy, + DenseTensor* d_dout) { + if (d_x) { + d_x->mutable_data(); + } + if (d_y) { + d_y->mutable_data(); + } + if (d_ddx) { + d_ddx->mutable_data(); + } + if (d_ddy) { + d_ddy->mutable_data(); + } + if (d_dout) { + d_dout->mutable_data(); + } + + DotTripleGradFunction()(dev_ctx, + &x, + &y, + ddx, + ddy, + d_dx, + d_dy, + dout, + d_ddout, + d_x, + d_y, + d_dout, + d_ddx, + d_ddy); +} + +} // namespace pten diff --git a/paddle/pten/kernels/impl/matmul_grad_kernel_impl.h b/paddle/pten/kernels/impl/matmul_grad_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..802cc019d78c5178c4c28b07586aa0ee8fbcd276 --- /dev/null +++ b/paddle/pten/kernels/impl/matmul_grad_kernel_impl.h @@ -0,0 +1,1742 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +// #include "paddle/pten/kernels/complex_kernel.h" +#include "paddle/pten/include/math.h" +#include "paddle/pten/kernels/empty_kernel.h" +#include "paddle/pten/kernels/impl/dot_grad_kernel_impl.h" +#include "paddle/pten/kernels/impl/matmul_kernel_impl.h" + +#include "paddle/pten/kernels/cpu/reduce.h" +#include "paddle/pten/kernels/funcs/reduce_functor.h" + +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/backends/gpu/gpu_context.h" + +#if defined(__NVCC__) || defined(__HIPCC__) +#include "paddle/pten/kernels/gpu/reduce.h" +#endif + +namespace pten { + +template +struct ReduceSumForMatmulGrad { + void operator()(const Context& dev_ctx, + const DenseTensor& input, + DenseTensor* output, + const std::vector& reduce_dims); +}; + +template +struct ReduceSumForMatmulGrad { + void operator()(const CPUContext& dev_ctx, + const DenseTensor& input, + DenseTensor* output, + const std::vector& reduce_dims) { + std::vector reduce_dims_tmp(reduce_dims.begin(), + reduce_dims.end()); + ReduceKernelImpl( + dev_ctx, input, output, reduce_dims_tmp, true, false); + } +}; + +#if defined(__NVCC__) || defined(__HIPCC__) +template +struct ReduceSumForMatmulGrad { + void operator()(const GPUContext& dev_ctx, + const DenseTensor& input, + DenseTensor* output, + const std::vector& reduce_dims) { + auto stream = dev_ctx.stream(); + kernels:: + TensorReduceFunctorImpl>( + input, output, kps::IdentityFunctor(), reduce_dims, stream); + } +}; +#endif + +// Reshape a rank-3 tensor from P x M x N to (P * M) x N. +// Identity op if the tensor is not of rank 3. +static DenseTensor FoldInitDims(const DenseTensor& input) { + DenseTensor output = input; + auto in_dims = input.dims(); + if (in_dims.size() == 3) { + output.Resize({in_dims[0] * in_dims[1], in_dims[2]}); + } + return output; +} + +// Reshape a rank-3 tensor from P x M x N to M x (P * N). +// (Warning: This requires transposing data and writes into new memory.) +// Identity op if the tensor is not of rank 3. +template +static DenseTensor FoldHeadAndLastDims(const Context& dev_ctx, + const DenseTensor& input) { + auto in_dims = input.dims(); + if (in_dims.size() != 3) { + return input; + } + DenseTensor output = EmptyLike(dev_ctx, input); + output.Resize({in_dims[1], in_dims[0], in_dims[2]}); + std::vector axis = {1, 0, 2}; + math::Transpose trans; + trans(dev_ctx, input, &output, axis); + output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); + return output; +} + +template +void MatMul(const Context& dev_ctx, + const DenseTensor& a, + bool trans_a, + const DenseTensor& b, + bool trans_b, + DenseTensor* out, + bool flag = false) { + out->mutable_data(); + auto blas = paddle::operators::math::GetBlas(dev_ctx); + auto mat_dim_a = + paddle::operators::math::CreateMatrixDescriptor(a.dims(), 0, trans_a); + auto mat_dim_b = + paddle::operators::math::CreateMatrixDescriptor(b.dims(), 0, trans_b); + if (a.dims().size() == 3 && b.dims().size() <= 2) { + // the transpose_X must be false, if is true, the transpose cost much time + if (!trans_a) { + mat_dim_a.height_ *= mat_dim_a.batch_size_; + mat_dim_a.batch_size_ = 0; + } + } + blas.MatMul(a.data(), + mat_dim_a, + b.data(), + mat_dim_b, + static_cast(1), + out->mutable_data(), + static_cast(flag)); +} + +/** + * Get row matrix shape from a vector shape. If the rank of x_dim > 1, the + * original x_dim is returned. 
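+ * For example, a 1-D shape [N] is returned as the row-matrix shape [1, N].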
+ */ +static DDim RowMatrixFromVector(const DDim& x_dim) { + if (x_dim.size() > 1) { + return x_dim; + } + return paddle::framework::make_ddim({1, x_dim[0]}); +} + +/** + * Get column matrix shape from a vector shape. If the ran of y_dim > 1, the + * original y_dim is returned. + */ +static DDim ColumnMatrixFromVector(const DDim& y_dim) { + if (y_dim.size() > 1) { + return y_dim; + } + return paddle::framework::make_ddim({y_dim[0], 1}); +} + +/** + * Reshape a tensor to 3-D or 2-D tensor by matrix descriptor. + * + * The shape would be [BatchSize, H, W] or [H, W]. + * If transposed, `H,W` will be swapped. + */ +static void ReshapeTensorIntoMatrixSequence( + DenseTensor* x, const paddle::operators::math::MatDescriptor& descriptor) { + int64_t h, w; + h = descriptor.height_; + w = descriptor.width_; + if (descriptor.trans_) { + std::swap(w, h); + } + if (descriptor.batch_size_) { + x->Resize({descriptor.batch_size_, h, w}); + } else { + x->Resize({h, w}); + } +} + +static void ReshapeXYOutIntoMatrixSequence(DenseTensor* x, + DenseTensor* y, + DenseTensor* out, + bool trans_x, + bool trans_y) { + auto x_dim = RowMatrixFromVector(x->dims()); + auto y_dim = ColumnMatrixFromVector(y->dims()); + auto mat_dim_x = + paddle::operators::math::CreateMatrixDescriptor(x_dim, 0, trans_x); + auto mat_dim_y = + paddle::operators::math::CreateMatrixDescriptor(y_dim, 0, trans_y); + if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { + out->Resize({mat_dim_x.height_, mat_dim_y.width_}); + } else { + out->Resize({(std::max)(mat_dim_x.batch_size_, mat_dim_y.batch_size_), + mat_dim_x.height_, + mat_dim_y.width_}); + } + + ReshapeTensorIntoMatrixSequence(x, mat_dim_x); + ReshapeTensorIntoMatrixSequence(y, mat_dim_y); +} + +template +void CalcInputGrad(const Context& dev_ctx, + const DenseTensor& a, + bool trans_a, + bool is_fold_init_dims_a, + const DenseTensor& b, + bool trans_b, + bool is_fold_init_dims_b, + DenseTensor* out, + bool flag = false) { + if (out == nullptr) return; + bool need_combine = + (a.dims().size() == 3 || b.dims().size() == 3) && out->dims().size() == 2; + if (!need_combine) { + MatMul(dev_ctx, a, trans_a, b, trans_b, out, flag); + } else { + MatMul( + dev_ctx, + is_fold_init_dims_a ? FoldInitDims(a) + : FoldHeadAndLastDims(dev_ctx, a), + trans_a, + is_fold_init_dims_b ? FoldInitDims(b) + : FoldHeadAndLastDims(dev_ctx, b), + trans_b, + out, + flag); + } +} + +template +void MatmulGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + bool transpose_x, + bool transpose_y, + DenseTensor* dx, + DenseTensor* dy) { + // get dims + std::vector x_dims = vectorize(x.dims()); + std::vector y_dims = vectorize(y.dims()); + std::vector dout_dims = vectorize(out_grad.dims()); + + int x_ndim = x_dims.size(); + int y_ndim = y_dims.size(); + int ndim = dout_dims.size(); + + // Case1 : x's or y's dim = 1 + if (x_ndim == 1 && y_ndim == 1) { + if (dx) dx->mutable_data(); + if (dy) dy->mutable_data(); + if (out_grad.numel() == 1) { + DotGradFunction()(dev_ctx, &x, &y, &out_grad, dx, dy); + return; + } + } + + bool is_broadcast = true; + if (x_ndim <= 2 || y_ndim <= 2) { + is_broadcast = false; + } else if (x_ndim != y_ndim) { + is_broadcast = true; + } else { + is_broadcast = !std::equal( + x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin()); + } + + // for complex + DenseTensor x_conj; + DenseTensor y_conj; + + // Case2: no broadcast or no batch size, it aims to speed and it is same as + // matmul in old version. 
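+  // In the branches below the gradients follow the usual matmul rules, with
+  // the operands conjugated so complex types are handled correctly; e.g. for
+  // Out = X * Y (no transpose): dX = dOut * conj(Y)^T, dY = conj(X)^T * dOut.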
+ if (!is_broadcast) { + DenseTensor x_help = x; + DenseTensor y_help = y; + DenseTensor out_grad_help = out_grad; + ReshapeXYOutIntoMatrixSequence( + &x_help, &y_help, &out_grad_help, transpose_x, transpose_y); + + DDim dx_dims; + if (dx) { + dx_dims = dx->dims(); + if (dx_dims != x_help.dims()) { + dx->Resize(x_help.dims()); + } + + y_conj = Conj(dev_ctx, y_help); + } + + DDim dy_dims; + if (dy) { + dy_dims = dy->dims(); + if (dy_dims != y_help.dims()) { + dy->Resize(y_help.dims()); + } + + x_conj = Conj(dev_ctx, x_help); + } + + if (transpose_x && transpose_y) { + CalcInputGrad( + dev_ctx, y_conj, true, true, out_grad_help, true, false, dx); + CalcInputGrad( + dev_ctx, out_grad_help, true, true, x_conj, true, false, dy); + } else if (transpose_x) { + CalcInputGrad( + dev_ctx, y_conj, false, false, out_grad_help, true, false, dx); + CalcInputGrad( + dev_ctx, x_conj, false, false, out_grad_help, false, true, dy); + } else if (transpose_y) { + CalcInputGrad( + dev_ctx, out_grad_help, false, false, y_conj, false, true, dx); + CalcInputGrad( + dev_ctx, out_grad_help, true, true, x_conj, false, true, dy); + } else { + CalcInputGrad( + dev_ctx, out_grad_help, false, false, y_conj, true, false, dx); + CalcInputGrad( + dev_ctx, x_conj, true, true, out_grad_help, false, true, dy); + } + + if (dx) { + if (dx_dims != x_help.dims()) { + dx->Resize(dx_dims); + } + } + if (dy) { + if (dy_dims != y_help.dims()) { + dy->Resize(dy_dims); + } + } + } else { + // Case3: broadcast. It need cost much time to reduce sum for the + // broadcast and wastes the memory. + // So we should avoid the case in reality. + VLOG(3) << "It need cost much time to reduce sum for the broadcast and " + "wastes the memory. So we should avoid the case in reality"; + x_conj = Conj(dev_ctx, x); + y_conj = Conj(dev_ctx, y); + + DenseTensor dx_help = Empty(dev_ctx); + DenseTensor dy_help = Empty(dev_ctx); + + if (transpose_x) { + if (transpose_y) { + // X'Y': dA = Y'G', dB = G'X' + if (dx) + MatMulFunction(dev_ctx, + y_conj, + out_grad, + y_dims, + dout_dims, + &dx_help, + true, + true); + if (dy) + MatMulFunction(dev_ctx, + out_grad, + x_conj, + dout_dims, + x_dims, + &dy_help, + true, + true); + } else { + // X'Y: dX = YG', dY = XG + if (dx) + MatMulFunction(dev_ctx, + y_conj, + out_grad, + y_dims, + dout_dims, + &dx_help, + false, + true); + if (dy) + MatMulFunction(dev_ctx, + x_conj, + out_grad, + x_dims, + dout_dims, + &dy_help, + false, + false); + } + } else { + if (transpose_y) { + // XY': dX = GY, dY = G'X + if (dx) + MatMulFunction(dev_ctx, + out_grad, + y_conj, + dout_dims, + y_dims, + &dx_help, + false, + false); + if (dy) + MatMulFunction(dev_ctx, + out_grad, + x_conj, + dout_dims, + x_dims, + &dy_help, + true, + false); + } else { + // XY: dX = GY', dY = X'G + if (dx) + MatMulFunction(dev_ctx, + out_grad, + y_conj, + dout_dims, + y_dims, + &dx_help, + false, + true); + if (dy) + MatMulFunction(dev_ctx, + x_conj, + out_grad, + x_dims, + dout_dims, + &dy_help, + true, + false); + } + } + + // get help dims + const std::vector dx_help_dims = vectorize(dx_help.dims()); + const std::vector dy_help_dims = vectorize(dy_help.dims()); + + std::vector dx_broadcast_dims(ndim); + std::vector dy_broadcast_dims(ndim); + + std::fill( + dx_broadcast_dims.data(), dx_broadcast_dims.data() + ndim - x_ndim, 1); + std::fill( + dy_broadcast_dims.data(), dy_broadcast_dims.data() + ndim - y_ndim, 1); + std::copy(x_dims.data(), + x_dims.data() + x_ndim, + dx_broadcast_dims.data() + ndim - x_ndim); + std::copy(y_dims.data(), + y_dims.data() + 
y_ndim, + dy_broadcast_dims.data() + ndim - y_ndim); + + std::vector dx_reduce_dims; + std::vector dy_reduce_dims; + for (int idx = 0; idx <= ndim - 3; idx++) { + if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { + dx_reduce_dims.push_back(idx); + } + if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { + dy_reduce_dims.push_back(idx); + } + } + // reduce sum to get grad by ReduceSum + if (dx) { + if (dx_reduce_dims.empty()) { + *dx = std::move(dx_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, dx_help, dx, dx_reduce_dims); + } + dx->Resize(x.dims()); + } + if (dy) { + if (dy_reduce_dims.empty()) { + *dy = std::move(dy_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, dy_help, dy, dy_reduce_dims); + } + dy->Resize(y.dims()); + } + // Get the OutputGrad(out) + } +} + +template +void MatmulDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + paddle::optional ddx, + paddle::optional ddy, + bool transpose_x, + bool transpose_y, + DenseTensor* dx, + DenseTensor* dy, + DenseTensor* ddout) { + // Get dims from the input x, y, output_grad + std::vector x_dims = vectorize(x.dims()); + std::vector y_dims = vectorize(y.dims()); + std::vector dout_dims = vectorize(dout.dims()); + + int x_ndim = x_dims.size(); + int y_ndim = y_dims.size(); + int ndim = dout_dims.size(); + + // Case1 : x's or y's dim = 1 + if (x_ndim == 1 && y_ndim == 1) { + DotDoubleGradFunction()( + dev_ctx, &x, &y, &dout, ddx.get_ptr(), ddy.get_ptr(), dx, dy, ddout); + return; + } + + DenseTensor x_conj; + DenseTensor y_conj; + DenseTensor dout_conj; + + bool is_broadcast = true; + if (x_ndim <= 2 || y_ndim <= 2) { + is_broadcast = false; + } else if (x_ndim != y_ndim) { + is_broadcast = true; + } else { + is_broadcast = !std::equal( + x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin()); + } + + if (!is_broadcast) { + // Case2: no broadcast or no batch size + DenseTensor x_help = x; + DenseTensor y_help = y; + DenseTensor dout_help = dout; + ReshapeXYOutIntoMatrixSequence( + &x_help, &y_help, &dout_help, transpose_x, transpose_y); + DDim dx_dims; + + if (dx) { + dx_dims = dx->dims(); + if (dx_dims != x_help.dims()) { + dx->Resize(x_help.dims()); + } + } + + DDim dy_dims; + if (dy) { + dy_dims = dy->dims(); + if (dy_dims != y_help.dims()) { + dy->Resize(y_help.dims()); + } + } + + DDim ddout_dims; + if (ddout) { + ddout_dims = ddout->dims(); + if (ddout_dims != dout_help.dims()) { + ddout->Resize(dout_help.dims()); + } + + x_conj = Conj(dev_ctx, x_help); + y_conj = Conj(dev_ctx, y_help); + } + + if (dx || dy) { + dout_conj = Conj(dev_ctx, dout_help); + } + + bool ddout_flag = false; + if (ddx) { + auto ddx_mat = ddx.get(); + if (ddx_mat.dims() != x_help.dims()) { + ddx_mat.Resize(x_help.dims()); + } + if (dy) { + if (transpose_x && transpose_y) { + // dy = dout' * ddx' + CalcInputGrad( + dev_ctx, dout_conj, true, true, ddx_mat, true, false, dy, false); + } else if (transpose_x) { + // dy = ddx * dout + CalcInputGrad(dev_ctx, + ddx_mat, + false, + false, + dout_conj, + false, + true, + dy, + false); + } else if (transpose_y) { + // dy = dout' * ddx + CalcInputGrad( + dev_ctx, dout_conj, true, true, ddx_mat, false, true, dy, false); + } else { + // dy = ddx' * dout + CalcInputGrad( + dev_ctx, ddx_mat, true, true, dout_conj, false, true, dy, false); + } + } + + if (ddout) { + CalcInputGrad(dev_ctx, + ddx_mat, + transpose_x, + true, + y_conj, + transpose_y, + false, + ddout, + ddout_flag); + ddout_flag = true; + } + } + + if 
(ddy) { + auto ddy_mat = ddy.get(); + if (ddy_mat.dims() != y_help.dims()) { + ddy_mat.Resize(y_help.dims()); + } + if (dx) { + if (transpose_x && transpose_y) { + // dx = ddy' * dout' + CalcInputGrad( + dev_ctx, ddy_mat, true, true, dout_conj, true, false, dx, false); + } else if (transpose_x) { + // dx = ddy * dout' + CalcInputGrad(dev_ctx, + ddy_mat, + false, + false, + dout_conj, + true, + false, + dx, + false); + } else if (transpose_y) { + // dx = dout * ddy + CalcInputGrad(dev_ctx, + dout_conj, + false, + false, + ddy_mat, + false, + true, + dx, + false); + } else { + // dx = dout * ddy' + CalcInputGrad(dev_ctx, + dout_conj, + false, + false, + ddy_mat, + true, + false, + dx, + false); + } + } + + if (ddout) { + CalcInputGrad(dev_ctx, + x_conj, + transpose_x, + true, + ddy_mat, + transpose_y, + false, + ddout, + ddout_flag); + } + } + + if (dx) { + if (dx_dims != x_help.dims()) { + dx->Resize(dx_dims); + } + } + + if (dy) { + if (dy_dims != y_help.dims()) { + dy->Resize(dy_dims); + } + } + + if (ddout) { + if (ddout_dims != dout_help.dims()) { + ddout->Resize(ddout_dims); + } + } + } else { + // Case3: broadcast. It need cost much time to reduce sum for the + // broadcast and wastes the memory. + // So we should avoid the case in reality. + VLOG(3) << "It need cost much time to reduce sum for the broadcast and " + "wastes the memory. So we should avoid the case in reality"; + if (dx || dy) { + dout_conj = Conj(dev_ctx, dout); + } + if (ddout) { + x_conj = Conj(dev_ctx, x); + y_conj = Conj(dev_ctx, y); + } + + DenseTensor dx_help = Empty(dev_ctx); + DenseTensor dy_help = Empty(dev_ctx); + + if (transpose_x) { + if (transpose_y) { + if (dx) { + MatMulFunction(dev_ctx, + ddy.get(), + dout_conj, + y_dims, + dout_dims, + &dx_help, + true, + true); + } + if (dy) { + MatMulFunction(dev_ctx, + dout_conj, + ddx.get(), + dout_dims, + x_dims, + &dy_help, + true, + true); + } + } else { + if (dx) + MatMulFunction(dev_ctx, + ddy.get(), + dout_conj, + y_dims, + dout_dims, + &dx_help, + false, + true); + if (dy) + MatMulFunction(dev_ctx, + ddx.get(), + dout_conj, + x_dims, + dout_dims, + &dy_help, + false, + false); + } + } else { + if (transpose_y) { + if (dx) { + MatMulFunction(dev_ctx, + dout_conj, + ddy.get(), + dout_dims, + y_dims, + &dx_help, + false, + false); + } + if (dy) { + MatMulFunction(dev_ctx, + dout_conj, + ddx.get(), + dout_dims, + x_dims, + &dy_help, + true, + false); + } + } else { + if (dx) { + MatMulFunction(dev_ctx, + dout_conj, + ddy.get(), + dout_dims, + y_dims, + &dx_help, + false, + true); + } + if (dy) { + MatMulFunction(dev_ctx, + ddx.get(), + dout_conj, + x_dims, + dout_dims, + &dy_help, + true, + false); + } + } + } + + // get help dims + const std::vector dx_help_dims = vectorize(dx_help.dims()); + const std::vector dy_help_dims = vectorize(dy_help.dims()); + + std::vector dx_broadcast_dims(ndim); + std::vector dy_broadcast_dims(ndim); + + std::fill( + dx_broadcast_dims.data(), dx_broadcast_dims.data() + ndim - x_ndim, 1); + std::fill( + dy_broadcast_dims.data(), dy_broadcast_dims.data() + ndim - y_ndim, 1); + std::copy(x_dims.data(), + x_dims.data() + x_ndim, + dx_broadcast_dims.data() + ndim - x_ndim); + std::copy(y_dims.data(), + y_dims.data() + y_ndim, + dy_broadcast_dims.data() + ndim - y_ndim); + + std::vector dx_reduce_dims; + std::vector dy_reduce_dims; + for (int idx = 0; idx <= ndim - 3; idx++) { + if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { + dx_reduce_dims.push_back(idx); + } + if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { 
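+        // y was broadcast along this batch axis, so its gradient has to be
+        // summed over it below.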
+ dy_reduce_dims.push_back(idx); + } + } + // Reduce sum to get grad by ReduceSum + if (dx) { + if (dx_reduce_dims.empty()) { + *dx = std::move(dx_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, dx_help, dx, dx_reduce_dims); + } + dx->Resize(x.dims()); + } + if (dy) { + if (dy_reduce_dims.empty()) { + *dy = std::move(dy_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, dy_help, dy, dy_reduce_dims); + } + dy->Resize(y.dims()); + } + + if (ddout) { + // Calculate the gradient of OutputGrad(Out) + MatMulFunction(dev_ctx, + ddx.get(), + y_conj, + x_dims, + y_dims, + ddout, + transpose_x, + transpose_y); + MatMulFunction(dev_ctx, + x_conj, + ddy.get(), + x_dims, + y_dims, + ddout, + transpose_x, + transpose_y, + true); + } + } +} + +template +void MatmulTripleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + const DenseTensor& ddx, + const DenseTensor& ddy, + paddle::optional d_dx, + paddle::optional d_dy, + paddle::optional d_ddout, + bool transpose_x, + bool transpose_y, + DenseTensor* out_d_x, + DenseTensor* out_d_y, + DenseTensor* out_d_dout, + DenseTensor* out_d_ddx, + DenseTensor* out_d_ddy) { + // Get dims from the input x, y, output_grad + std::vector x_dims = vectorize(x.dims()); + std::vector y_dims = vectorize(y.dims()); + std::vector dout_dims = vectorize(dout.dims()); + + int x_ndim = x_dims.size(); + int y_ndim = y_dims.size(); + int ndim = dout_dims.size(); + + // Case1 : x's and y's dim = 1 + if (x_ndim == 1 && y_ndim == 1) { + VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 1"; + DotTripleGradFunction()(dev_ctx, + &x, + &y, + &ddx, + &ddy, + d_dx.get_ptr(), + d_dy.get_ptr(), + &dout, + d_ddout.get_ptr(), + out_d_x, + out_d_y, + out_d_dout, + out_d_ddx, + out_d_ddy); + return; + } + + DenseTensor x_conj; + DenseTensor y_conj; + DenseTensor dout_conj; + DenseTensor ddx_conj; + DenseTensor ddy_conj; + + bool is_broadcast = true; + if (x_ndim <= 2 || y_ndim <= 2) { + is_broadcast = false; + } else if (x_ndim != y_ndim) { + is_broadcast = true; + } else { + is_broadcast = !std::equal( + x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin()); + } + + if (!is_broadcast) { + // Case2: no broadcast or no batch size + VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 2"; + DenseTensor x_help = x; + DenseTensor y_help = y; + DenseTensor dout_help = dout; + DenseTensor ddx_help = ddx; + DenseTensor ddy_help = ddy; + ReshapeXYOutIntoMatrixSequence( + &x_help, &y_help, &dout_help, transpose_x, transpose_y); + + if (ddx_help.dims() != x_help.dims()) { + ddx_help.Resize(x_help.dims()); + } + + if (ddy_help.dims() != y_help.dims()) { + ddy_help.Resize(y_help.dims()); + } + + DDim out_dx_dims; + if (out_d_x) { + out_dx_dims = out_d_x->dims(); + if (out_dx_dims != x_help.dims()) { + out_d_x->Resize(x_help.dims()); + } + } + + DDim out_dy_dims; + if (out_d_y) { + out_dy_dims = out_d_y->dims(); + if (out_dy_dims != y_help.dims()) { + out_d_y->Resize(y_help.dims()); + } + } + + DDim out_d_dout_dims; + if (out_d_dout) { + out_d_dout_dims = out_d_dout->dims(); + if (out_d_dout_dims != dout_help.dims()) { + out_d_dout->Resize(dout_help.dims()); + } + + ddx_conj = Conj(dev_ctx, ddx_help); + ddy_conj = Conj(dev_ctx, ddy_help); + } + + DDim out_d_ddx_dims; + if (out_d_ddx) { + out_d_ddx_dims = out_d_ddx->dims(); + if (out_d_ddx_dims != x_help.dims()) { + out_d_ddx->Resize(x_help.dims()); + } + } + + DDim out_d_ddy_dims; + if (out_d_ddy) { + out_d_ddy_dims = out_d_ddy->dims(); + if 
(out_d_ddy_dims != y_help.dims()) { + out_d_ddy->Resize(y_help.dims()); + } + } + + if (out_d_ddx || out_d_ddy) { + x_conj = Conj(dev_ctx, x_help); + y_conj = Conj(dev_ctx, y_help); + dout_conj = Conj(dev_ctx, dout_help); + } + + bool d_dout_flag = false; + bool d_ddx_flag = false; + bool d_ddy_flag = false; + + if (d_ddout) { + auto d_ddout_mat = d_ddout.get(); + if (d_ddout_mat.dims() != dout_help.dims()) { + d_ddout_mat.Resize(dout_help.dims()); + } + + if (out_d_y) { + if (transpose_x && transpose_y) { + // out_d_y = d_ddout' * ddx' + CalcInputGrad(dev_ctx, + d_ddout_mat, + true, + true, + ddx_conj, + true, + false, + out_d_y, + false); + } else if (transpose_x) { + // out_d_y = ddx * d_ddout + CalcInputGrad(dev_ctx, + ddx_conj, + false, + false, + d_ddout_mat, + false, + true, + out_d_y, + false); + } else if (transpose_y) { + // out_d_y = d_ddout' * ddx + CalcInputGrad(dev_ctx, + d_ddout_mat, + true, + true, + ddx_conj, + false, + true, + out_d_y, + false); + } else { + // out_d_y = ddx' * d_ddout + CalcInputGrad(dev_ctx, + ddx_conj, + true, + true, + d_ddout_mat, + false, + true, + out_d_y, + false); + } + } + if (out_d_x) { + if (transpose_x && transpose_y) { + // out_d_x = ddy' * d_ddout' + CalcInputGrad(dev_ctx, + ddy_conj, + true, + true, + d_ddout_mat, + true, + false, + out_d_x, + false); + } else if (transpose_x) { + // out_d_x = ddy * d_ddout' + CalcInputGrad(dev_ctx, + ddy_conj, + false, + false, + d_ddout_mat, + true, + false, + out_d_x, + false); + } else if (transpose_y) { + // out_d_x = d_ddout * ddy + CalcInputGrad(dev_ctx, + d_ddout_mat, + false, + false, + ddy_conj, + false, + true, + out_d_x, + false); + } else { + // out_d_x = d_ddout * ddy' + CalcInputGrad(dev_ctx, + d_ddout_mat, + false, + false, + ddy_conj, + true, + false, + out_d_x, + false); + } + } + + // equations: + // d_ddx = DOut * D_DY + Y * D_DDOut + // Let: d_ddx1 = Y * D_DDOut + // Let: d_ddx2 = DOut * D_DY + + // d_ddy = DOut * D_DX + X * D_DDOut + // Let: d_ddy1 = X * D_DDOut + // Let: d_ddy2 = DOut * D_DX + + // d_dout = DDY * D_DX + DDX * D_DY + // Let: d_dout1 = DDX * D_DY + // Let: d_dout2 = DDY * D_DX + + // compute d_ddx1 + if (out_d_ddx) { + if (transpose_x && transpose_y) { + // out_d_ddx1 = y' * d_ddout' + CalcInputGrad(dev_ctx, + y_conj, + true, + true, + d_ddout_mat, + true, + false, + out_d_ddx, + d_ddx_flag); + } else if (transpose_x) { + // out_d_ddx1 = y * d_ddout' + CalcInputGrad(dev_ctx, + y_conj, + false, + false, + d_ddout_mat, + true, + false, + out_d_ddx, + d_ddx_flag); + } else if (transpose_y) { + // out_d_ddx1 = d_ddout * y + CalcInputGrad(dev_ctx, + d_ddout_mat, + false, + false, + y_conj, + false, + true, + out_d_ddx, + d_ddx_flag); + } else { + // out_d_ddx1 = d_ddout * y' + CalcInputGrad(dev_ctx, + d_ddout_mat, + false, + false, + y_conj, + true, + false, + out_d_ddx, + d_ddx_flag); + } + d_ddx_flag = true; + } + + // compute d_ddy1 + if (out_d_ddy) { + if (transpose_x && transpose_y) { + // out_d_ddy1 = d_ddout' * x' + CalcInputGrad(dev_ctx, + d_ddout_mat, + true, + true, + x_conj, + true, + false, + out_d_ddy, + false); + } else if (transpose_x) { + // out_d_ddy1 = x * d_ddout + CalcInputGrad(dev_ctx, + x_conj, + false, + false, + d_ddout_mat, + false, + true, + out_d_ddy, + false); + } else if (transpose_y) { + // out_d_ddy1 = d_ddout' * x + CalcInputGrad(dev_ctx, + d_ddout_mat, + true, + true, + x_conj, + false, + true, + out_d_ddy, + false); + } else { + // out_d_ddy1 = x' * d_ddout + CalcInputGrad(dev_ctx, + x_conj, + true, + true, + d_ddout_mat, + false, + true, + 
out_d_ddy, + false); + } + d_ddy_flag = true; + } + } + + if (d_dy) { + auto d_dy_mat = d_dy.get(); + if (d_dy_mat.dims() != y_help.dims()) { + d_dy_mat.Resize(y_help.dims()); + } + + // compute d_dout1 + if (out_d_dout) { + CalcInputGrad(dev_ctx, + ddx_conj, + transpose_x, + true, + d_dy_mat, + transpose_y, + false, + out_d_dout, + d_dout_flag); + d_dout_flag = true; + } + + // compute d_ddx2 + if (out_d_ddx) { + if (transpose_x && transpose_y) { + // out_d_ddx2 = D_DY' * DOut' + CalcInputGrad(dev_ctx, + d_dy_mat, + true, + true, + dout_conj, + true, + false, + out_d_ddx, + d_ddx_flag); + } else if (transpose_x) { + // out_d_ddx2 = D_DY * Dout' + CalcInputGrad(dev_ctx, + d_dy_mat, + false, + false, + dout_conj, + true, + false, + out_d_ddx, + d_ddx_flag); + } else if (transpose_y) { + // out_d_ddx2 = Dout * D_DY + CalcInputGrad(dev_ctx, + dout_conj, + false, + false, + d_dy_mat, + false, + true, + out_d_ddx, + d_ddx_flag); + } else { + // out_d_ddx2 = Dout * D_DY' + CalcInputGrad(dev_ctx, + dout_conj, + false, + false, + d_dy_mat, + true, + false, + out_d_ddx, + d_ddx_flag); + } + } + } + + if (d_dx) { + auto d_dx_mat = d_dx.get(); + if (d_dx_mat.dims() != x_help.dims()) { + d_dx_mat.Resize(x_help.dims()); + } + + // compute d_dout2 + if (out_d_dout) { + CalcInputGrad(dev_ctx, + d_dx_mat, + transpose_x, + true, + ddy_conj, + transpose_y, + false, + out_d_dout, + d_dout_flag); + } + + // compute d_ddy2 + if (out_d_ddy) { + if (transpose_x && transpose_y) { + // out_d_ddy2 = dout' * d_dx' + CalcInputGrad(dev_ctx, + dout_conj, + true, + true, + d_dx_mat, + true, + false, + out_d_ddy, + d_ddy_flag); + } else if (transpose_x) { + // out_d_ddy2 = d_dx * dout + CalcInputGrad(dev_ctx, + d_dx_mat, + false, + false, + dout_conj, + false, + true, + out_d_ddy, + d_ddy_flag); + } else if (transpose_y) { + // out_d_ddy2 = dout' * d_dx + CalcInputGrad(dev_ctx, + dout_conj, + true, + true, + d_dx_mat, + false, + true, + out_d_ddy, + d_ddy_flag); + } else { + // out_d_ddy2 = d_dx' * dout + CalcInputGrad(dev_ctx, + d_dx_mat, + true, + true, + dout_conj, + false, + true, + out_d_ddy, + d_ddy_flag); + } + } + } + + if (out_d_x) { + if (out_dx_dims != x_help.dims()) { + out_d_x->Resize(out_dx_dims); + } + } + + if (out_d_y) { + if (out_dy_dims != y_help.dims()) { + out_d_y->Resize(out_dy_dims); + } + } + + if (out_d_dout) { + if (out_d_dout_dims != dout_help.dims()) { + out_d_dout->Resize(out_d_dout_dims); + } + } + + if (out_d_ddx) { + if (out_d_ddx_dims != x_help.dims()) { + out_d_ddx->Resize(out_d_ddx_dims); + } + } + + if (out_d_ddy) { + if (out_d_ddy_dims != y_help.dims()) { + out_d_ddy->Resize(out_d_ddy_dims); + } + } + } else { + // Case3: broadcast. It need cost much time to reduce sum for the + // broadcast and wastes the memory. + // So we should avoid the case in reality. + VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 3"; + VLOG(3) << "It need cost much time to reduce sum for the broadcast and " + "wastes the memory. 
So we should avoid the case in reality"; + + DenseTensor out_dx_help = Empty(dev_ctx); + DenseTensor out_dy_help = Empty(dev_ctx); + DenseTensor out_d_ddx_help = Empty(dev_ctx); + DenseTensor out_d_ddy_help = Empty(dev_ctx); + + if (out_d_dout) { + ddx_conj = Conj(dev_ctx, ddx); + ddy_conj = Conj(dev_ctx, ddy); + } + if (out_d_ddx || out_d_ddy) { + x_conj = Conj(dev_ctx, x); + y_conj = Conj(dev_ctx, y); + dout_conj = Conj(dev_ctx, dout); + } + + if (transpose_x) { + if (transpose_y) { + // dX = ddY' d_ddout’, dY = d_ddout’ ddX' + if (out_d_x) + MatMulFunction(dev_ctx, + ddy_conj, + d_ddout.get(), + y_dims, + dout_dims, + &out_dx_help, + true, + true); + if (out_d_y) + MatMulFunction(dev_ctx, + d_ddout.get(), + ddx_conj, + dout_dims, + x_dims, + &out_dy_help, + true, + true); + } else { + // dX = ddY d_ddout', dY = ddX d_ddout + if (out_d_x) + MatMulFunction(dev_ctx, + ddy_conj, + d_ddout.get(), + y_dims, + dout_dims, + &out_dx_help, + false, + true); + if (out_d_y) + MatMulFunction(dev_ctx, + ddx_conj, + d_ddout.get(), + x_dims, + dout_dims, + &out_dy_help, + false, + false); + } + } else { + if (transpose_y) { + // dX = d_ddout ddY, dY = d_ddout’ ddX + if (out_d_x) + MatMulFunction(dev_ctx, + d_ddout.get(), + ddy_conj, + dout_dims, + y_dims, + &out_dx_help, + false, + false); + if (out_d_y) + MatMulFunction(dev_ctx, + d_ddout.get(), + ddx_conj, + dout_dims, + x_dims, + &out_dy_help, + true, + false); + } else { + // dX = d_ddout ddY', dY = ddX' d_ddout + if (out_d_x) + MatMulFunction(dev_ctx, + d_ddout.get(), + ddy_conj, + dout_dims, + y_dims, + &out_dx_help, + false, + true); + if (out_d_y) + MatMulFunction(dev_ctx, + ddx_conj, + d_ddout.get(), + x_dims, + dout_dims, + &out_dy_help, + true, + false); + } + } + + // get help dims + const std::vector dx_help_dims = + vectorize(out_dx_help.dims()); + const std::vector dy_help_dims = + vectorize(out_dx_help.dims()); + + std::vector dx_broadcast_dims(ndim); + std::vector dy_broadcast_dims(ndim); + + std::fill( + dx_broadcast_dims.data(), dx_broadcast_dims.data() + ndim - x_ndim, 1); + std::fill( + dy_broadcast_dims.data(), dy_broadcast_dims.data() + ndim - y_ndim, 1); + std::copy(x_dims.data(), + x_dims.data() + x_ndim, + dx_broadcast_dims.data() + ndim - x_ndim); + std::copy(y_dims.data(), + y_dims.data() + y_ndim, + dy_broadcast_dims.data() + ndim - y_ndim); + + std::vector dx_reduce_dims; + std::vector dy_reduce_dims; + for (int idx = 0; idx <= ndim - 3; idx++) { + if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { + dx_reduce_dims.push_back(idx); + } + if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { + dy_reduce_dims.push_back(idx); + } + } + // Reduce sum to get grad by ReduceSum + if (out_d_x) { + if (dx_reduce_dims.empty()) { + *out_d_x = std::move(out_dx_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, out_dx_help, out_d_x, dx_reduce_dims); + } + out_d_x->Resize(x.dims()); + } + + if (out_d_y) { + if (dy_reduce_dims.empty()) { + *out_d_y = std::move(out_dy_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, out_dy_help, out_d_y, dy_reduce_dims); + } + out_d_y->Resize(y.dims()); + } + + // compute d_dout + if (out_d_dout) { + MatMulFunction(dev_ctx, + d_dx.get(), + ddy_conj, + x_dims, + y_dims, + out_d_dout, + transpose_x, + transpose_y); + MatMulFunction(dev_ctx, + ddx_conj, + d_dy.get(), + x_dims, + y_dims, + out_d_dout, + transpose_x, + transpose_y, + true); + } + // compute d_ddx + if (out_d_ddx) { + if (transpose_x && transpose_y) { + // out_d_ddx1 = y' * d_ddout' + MatMulFunction(dev_ctx, + 
y_conj, + d_ddout.get(), + y_dims, + dout_dims, + &out_d_ddx_help, + true, + true); + // out_d_ddx2 = D_DY' * DOut' + MatMulFunction(dev_ctx, + d_dy.get(), + dout_conj, + y_dims, + dout_dims, + &out_d_ddx_help, + true, + true, + true); + } else if (transpose_x) { + // out_d_ddx1 = y * d_ddout' + MatMulFunction(dev_ctx, + y_conj, + d_ddout.get(), + y_dims, + dout_dims, + &out_d_ddx_help, + false, + true); + // out_d_ddx2 = D_DY * Dout' + MatMulFunction(dev_ctx, + d_dy.get(), + dout_conj, + y_dims, + dout_dims, + &out_d_ddx_help, + false, + true, + true); + } else if (transpose_y) { + // out_d_ddx1 = d_ddout * y + MatMulFunction(dev_ctx, + d_ddout.get(), + y_conj, + dout_dims, + y_dims, + &out_d_ddx_help, + false, + false); + // out_d_ddx2 = Dout * D_DY + MatMulFunction(dev_ctx, + dout_conj, + d_dy.get(), + dout_dims, + y_dims, + &out_d_ddx_help, + false, + false, + true); + } else { + // out_d_ddx1 = d_ddout * y' + MatMulFunction(dev_ctx, + d_ddout.get(), + y_conj, + dout_dims, + y_dims, + &out_d_ddx_help, + false, + true); + // out_d_ddx2 = Dout * D_DY' + MatMulFunction(dev_ctx, + dout_conj, + d_dy.get(), + dout_dims, + y_dims, + &out_d_ddx_help, + false, + true, + true); + } + if (dx_reduce_dims.empty()) { + *out_d_ddx = std::move(out_d_ddx_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, out_d_ddx_help, out_d_ddx, dx_reduce_dims); + } + out_d_ddx->Resize(x.dims()); + } + + // compute d_ddy + if (out_d_ddy) { + if (transpose_x && transpose_y) { + // out_d_ddy1 = d_ddout' * x' + MatMulFunction(dev_ctx, + d_ddout.get(), + x_conj, + dout_dims, + x_dims, + &out_d_ddy_help, + true, + true); + // out_d_ddy2 = dout' * d_dx' + MatMulFunction(dev_ctx, + dout_conj, + d_dx.get(), + dout_dims, + x_dims, + &out_d_ddy_help, + true, + true, + true); + } else if (transpose_x) { + // out_d_ddy1 = x * d_ddout + MatMulFunction(dev_ctx, + x_conj, + d_ddout.get(), + x_dims, + dout_dims, + &out_d_ddy_help, + false, + false); + // out_d_ddy2 = d_dx * dout + MatMulFunction(dev_ctx, + d_dx.get(), + dout_conj, + x_dims, + dout_dims, + &out_d_ddy_help, + false, + false, + true); + } else if (transpose_y) { + // out_d_ddy1 = d_ddout' * x + MatMulFunction(dev_ctx, + d_ddout.get(), + x_conj, + dout_dims, + x_dims, + &out_d_ddy_help, + true, + false); + // out_d_ddy2 = dout' * d_dx + MatMulFunction(dev_ctx, + dout_conj, + d_dx.get(), + dout_dims, + x_dims, + &out_d_ddy_help, + true, + false, + true); + } else { + // out_d_ddy1 = x' * d_ddout + MatMulFunction(dev_ctx, + x_conj, + d_ddout.get(), + x_dims, + dout_dims, + &out_d_ddy_help, + true, + false); + // out_d_ddy2 = d_dx' * dout + MatMulFunction(dev_ctx, + d_dx.get(), + dout_conj, + x_dims, + dout_dims, + &out_d_ddy_help, + true, + false, + true); + } + + if (dy_reduce_dims.empty()) { + *out_d_ddy = std::move(out_d_ddy_help); + } else { + ReduceSumForMatmulGrad()( + dev_ctx, out_d_ddy_help, out_d_ddy, dy_reduce_dims); + } + out_d_ddy->Resize(y.dims()); + } + } +} + +} // namespace pten diff --git a/paddle/pten/kernels/impl/matmul_kernel_impl.h b/paddle/pten/kernels/impl/matmul_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..f5f69f327a69f2218c0188a96015c85ea7e08f38 --- /dev/null +++ b/paddle/pten/kernels/impl/matmul_kernel_impl.h @@ -0,0 +1,509 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/complex_functors.h" + +#include "paddle/pten/core/dense_tensor.h" + +namespace pten { + +static void GetBroadcastFromDims(const int x_ndim, + const std::int64_t* x_dims, + const int y_ndim, + const std::int64_t* y_dims, + std::int64_t* x_bd_dims, + std::int64_t* y_bd_dims, + std::int64_t* out_bd_dims) { + const int ndim = (std::max)(x_ndim, y_ndim); + std::fill(x_bd_dims, x_bd_dims + ndim - x_ndim, 1); + std::fill(y_bd_dims, y_bd_dims + ndim - y_ndim, 1); + std::copy(x_dims, x_dims + x_ndim, x_bd_dims + ndim - x_ndim); + std::copy(y_dims, y_dims + y_ndim, y_bd_dims + ndim - y_ndim); + + for (int i = 0; i < ndim; ++i) { + PADDLE_ENFORCE_EQ( + x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] <= 1 || y_bd_dims[i] <= 1, + true, + paddle::platform::errors::InvalidArgument( + "Input(X) and Input(Y) has error dim." + "X_broadcast's shape[%s] must be equal to Y_broadcast's shape[%s]," + "or X_broadcast's shape[%s] <= 1, or Y_broadcast's shape[%s] <= 1," + "But received X_broadcast's shape[%s] = [%s]" + "received Y_broadcast's shape[%s] = [%s]", + i, + i, + i, + i, + i, + x_bd_dims[i], + i, + y_bd_dims[i])); + if (x_bd_dims[i] == 0 || y_bd_dims[i] == 0) { + out_bd_dims[i] = 0; + } else { + out_bd_dims[i] = (std::max)(x_bd_dims[i], y_bd_dims[i]); + } + } +} + +static int64_t GetIndexMessage(const int n, + const int64_t* dims, + const int64_t* index) { + int64_t sum = 0; + for (int i = 0; i < n; ++i) { + if (dims[i] > 1) { + sum = sum * dims[i] + index[i]; + } + } + return sum; +} + +static void IndexIncreaseFromDims(const int ndim, + const int64_t* dims, + int64_t* index) { + for (int i = ndim - 1; i >= 0; --i) { + ++index[i]; + if (index[i] >= dims[i]) { + index[i] -= dims[i]; + } else { + break; + } + } +} + +template +void MatMulFunction(const Context& dev_ctx, + const DenseTensor& X, + const DenseTensor& Y, + const std::vector& x_dims, + const std::vector& y_dims, + DenseTensor* Out, + bool trans_x, + bool trans_y, + bool flag = false) { + const int x_ndim = x_dims.size(); + const int y_ndim = y_dims.size(); + + // Get data ptr + const T* x_data = X.data(); + const T* y_data = Y.data(); + + auto blas = paddle::operators::math::GetBlas(dev_ctx); + + if (x_ndim == 1 && y_ndim == 1) { + const int M = X.numel(); + const int N = Y.numel(); + PADDLE_ENFORCE_EQ( + M, + N, + paddle::platform::errors::InvalidArgument( + "X's numbers must be equal to Y's numbers," + "when X/Y's dims =1. But received X has [%d] elements," + "received Y has [%d] elements", + M, + N)); + VLOG(3) << "MatMul's case 1"; + Out->Resize({1}); + Out->mutable_data(); + blas.GEMM(CblasNoTrans, + CblasTrans, + 1, + 1, + M, + static_cast(1), + y_data, + x_data, + static_cast(flag), + Out->mutable_data()); + return; + } + + if (x_ndim == 1) { + const int N = X.numel(); + if (trans_y) { + PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], + N, + paddle::platform::errors::InvalidArgument( + "Input(Y) has error dim." 
+ "Y'dims[%d] must be equal to %d" + "But received Y'dims[%d] is %d", + y_ndim - 1, + N, + y_ndim - 1, + y_dims[y_ndim - 1])); + } else { + PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], + N, + paddle::platform::errors::InvalidArgument( + "Input(Y) has error dim." + "Y'dims[%d] must be equal to %d" + "But received Y'dims[%d] is %d", + y_ndim - 2, + N, + y_ndim - 2, + y_dims[y_ndim - 2])); + } + std::vector out_dims(y_ndim - 1); + if (trans_y) { + std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin()); + } else { + std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin()); + out_dims.back() = y_dims.back(); + } + Out->Resize(paddle::framework::make_ddim(out_dims)); + Out->mutable_data(); + if (trans_y) { + const int M = Y.numel() / N; + VLOG(3) << "MatMul's case 2"; + blas.GEMV(false, + M, + N, + static_cast(1), + y_data, + x_data, + static_cast(flag), + Out->mutable_data()); + } else { + const int M = y_dims[y_ndim - 1]; + const int batch_size = Y.numel() / (M * N); + if (batch_size == 1) { + VLOG(3) << "MatMul's case 3"; + blas.GEMV(true, + N, + M, + static_cast(1), + y_data, + x_data, + static_cast(flag), + Out->mutable_data()); + } else { + VLOG(3) << "MatMul's case 4"; + blas.BatchedGEMM(CblasTrans, + CblasNoTrans, + M, + 1, + N, + static_cast(1), + y_data, + x_data, + static_cast(flag), + Out->mutable_data(), + batch_size, + M * N, + 0); + } + } + return; + } + + if (y_ndim == 1) { + const int N = Y.numel(); + if (trans_x) { + PADDLE_ENFORCE_EQ(x_dims[x_ndim - 2], + N, + paddle::platform::errors::InvalidArgument( + "Input(X) has error dim." + "X'dims[%d] must be equal to %d" + "But received X'dims[%d] is %d", + x_ndim - 2, + N, + x_ndim - 2, + x_dims[x_ndim - 2])); + } else { + PADDLE_ENFORCE_EQ(x_dims[x_ndim - 1], + N, + paddle::platform::errors::InvalidArgument( + "Input(X) has error dim." + "X'dims[%d] must be equal to %d" + "But received X'dims[%d] is %d", + x_ndim - 1, + N, + x_ndim - 1, + x_dims[x_ndim - 1])); + } + std::vector out_dims(x_ndim - 1); + if (trans_x) { + std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin()); + out_dims.back() = x_dims.back(); + } else { + std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin()); + } + Out->Resize(paddle::framework::make_ddim(out_dims)); + Out->mutable_data(); + + if (trans_x) { + const int M = x_dims[x_ndim - 1]; + const int batch_size = X.numel() / (M * N); + if (batch_size == 1) { + VLOG(3) << "MatMul's case 5"; + blas.GEMV(true, + N, + M, + static_cast(1), + x_data, + y_data, + static_cast(flag), + Out->mutable_data()); + } else { + VLOG(3) << "MatMul's case 6"; + blas.BatchedGEMM(CblasTrans, + CblasNoTrans, + M, + 1, + N, + static_cast(1), + x_data, + y_data, + static_cast(flag), + Out->mutable_data(), + batch_size, + M * N, + 0); + } + } else { + const int M = X.numel() / N; + VLOG(3) << "MatMul's case 7"; + blas.GEMV(false, + M, + N, + static_cast(1), + x_data, + y_data, + static_cast(flag), + Out->mutable_data()); + } + return; + } + + const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; + const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; + if (trans_y) { + PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], + K, + paddle::platform::errors::InvalidArgument( + "Input(Y) has error dim." + "Y'dims[%d] must be equal to %d" + "But received Y'dims[%d] is %d", + y_ndim - 1, + K, + y_ndim - 1, + y_dims[y_ndim - 1])); + } else { + PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], + K, + paddle::platform::errors::InvalidArgument( + "Input(Y) has error dim." 
+ "Y'dims[%d] must be equal to %d" + "But received Y'dims[%d] is %d", + y_ndim - 2, + K, + y_ndim - 2, + y_dims[y_ndim - 2])); + } + const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; + const int ndim = (std::max)(x_ndim, y_ndim); + std::vector x_broadcast_dims(ndim); + std::vector y_broadcast_dims(ndim); + std::vector out_broadcast_dims(ndim); + + GetBroadcastFromDims(x_ndim - 2, + x_dims.data(), + y_ndim - 2, + y_dims.data(), + x_broadcast_dims.data(), + y_broadcast_dims.data(), + out_broadcast_dims.data()); + out_broadcast_dims[ndim - 2] = M; + out_broadcast_dims[ndim - 1] = N; + + Out->Resize(paddle::framework::make_ddim(out_broadcast_dims)); + Out->mutable_data(); + + const int batch_dim = ndim - 2; + // broadcast message + const bool is_broadcast_dims = + !std::equal(x_broadcast_dims.cbegin(), + x_broadcast_dims.cbegin() + batch_dim, + y_broadcast_dims.cbegin()); + + const std::int64_t x_batch_size = + std::accumulate(x_broadcast_dims.cbegin(), + x_broadcast_dims.cbegin() + batch_dim, + 1LL, + std::multiplies()); + const std::int64_t y_batch_size = + std::accumulate(y_broadcast_dims.cbegin(), + y_broadcast_dims.cbegin() + batch_dim, + 1LL, + std::multiplies()); + const std::int64_t out_batch_size = + std::accumulate(out_broadcast_dims.cbegin(), + out_broadcast_dims.cbegin() + batch_dim, + 1LL, + std::multiplies()); + if (out_batch_size == 0) return; + if (x_batch_size == 1 && y_batch_size == 1) { + VLOG(3) << "MatMul's case 8"; + blas.GEMM(trans_x ? CblasTrans : CblasNoTrans, + trans_y ? CblasTrans : CblasNoTrans, + M, + N, + K, + static_cast(1), + x_data, + y_data, + static_cast(flag), + Out->mutable_data()); + } else if (x_batch_size == 1) { + if (M == 1 && trans_y) { + VLOG(3) << "MatMul's case 9"; + blas.GEMV(false, + y_batch_size * N, + K, + static_cast(1), + y_data, + x_data, + static_cast(flag), + Out->mutable_data()); + } else { + VLOG(3) << "MatMul's case 10"; + blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, + trans_y ? CblasTrans : CblasNoTrans, + M, + N, + K, + static_cast(1), + x_data, + y_data, + static_cast(flag), + Out->mutable_data(), + out_batch_size, + 0, + K * N); + } + } else if (y_batch_size == 1) { + if (!trans_x) { + VLOG(3) << "MatMul's case 11"; + blas.GEMM(CblasNoTrans, + trans_y ? CblasTrans : CblasNoTrans, + x_batch_size * M, + N, + K, + static_cast(1), + x_data, + y_data, + static_cast(flag), + Out->mutable_data()); + } else { + VLOG(3) << "MatMul's case 12"; + blas.BatchedGEMM(CblasTrans, + trans_y ? CblasTrans : CblasNoTrans, + M, + N, + K, + static_cast(1), + x_data, + y_data, + static_cast(flag), + Out->mutable_data(), + out_batch_size, + M * K, + 0); + } + } else if (!is_broadcast_dims) { + VLOG(3) << "MatMul's case 13"; + blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, + trans_y ? 
CblasTrans : CblasNoTrans, + M, + N, + K, + static_cast(1), + x_data, + y_data, + static_cast(flag), + Out->mutable_data(), + out_batch_size, + M * K, + K * N); + } else { + // in the case, can't use stridedgemm + std::vector x_ptr(out_batch_size); + std::vector y_ptr(out_batch_size); + std::vector out_ptr(out_batch_size); + std::vector index(batch_dim, 0); + for (std::int64_t i = 0; i < out_batch_size; ++i) { + // using the index to get offset + const std::int64_t x_index = + GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data()); + const std::int64_t y_index = + GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data()); + + x_ptr[i] = x_data + x_index * M * K; + y_ptr[i] = y_data + y_index * K * N; + out_ptr[i] = Out->mutable_data() + i * M * N; + IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data()); + } + VLOG(3) << "MatMul's case 14"; + blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, + trans_y ? CblasTrans : CblasNoTrans, + M, + N, + K, + static_cast(1), + x_ptr.data(), + y_ptr.data(), + static_cast(flag), + out_ptr.data(), + out_batch_size); + } +} + +template +void MatMulFunction(const Context& dev_ctx, + const DenseTensor& X, + const DenseTensor& Y, + DenseTensor* Out, + bool trans_x, + bool trans_y, + bool flag = false) { + const std::vector x_dims = vectorize(X.dims()); + const std::vector y_dims = vectorize(Y.dims()); + MatMulFunction( + dev_ctx, X, Y, x_dims, y_dims, Out, trans_x, trans_y, flag); +} + +template +void MatmulKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + bool transpose_x, + bool transpose_y, + DenseTensor* out) { + PADDLE_ENFORCE_NE(paddle::framework::product(x.dims()), + 0, + paddle::platform::errors::InvalidArgument( + "The Input(X) dims size must not be equal 0," + " but reviced dims size is 0. ")); + PADDLE_ENFORCE_NE(paddle::framework::product(y.dims()), + 0, + paddle::platform::errors::InvalidArgument( + "The Input(Y) dims size must not be equal 0," + " but reviced dims size is 0. ")); + MatMulFunction(dev_ctx, x, y, out, transpose_x, transpose_y); +} + +} // namespace pten diff --git a/paddle/pten/kernels/matmul_grad_kernel.h b/paddle/pten/kernels/matmul_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..db485b79d2736958e6bb6e549adc9a2c27780419 --- /dev/null +++ b/paddle/pten/kernels/matmul_grad_kernel.h @@ -0,0 +1,63 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/utils/optional.h"
+
+namespace pten {
+
+template <typename T, typename Context>
+void MatmulGradKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const DenseTensor& y,
+                      const DenseTensor& dout,
+                      bool transpose_x,
+                      bool transpose_y,
+                      DenseTensor* dx,
+                      DenseTensor* dy);
+
+template <typename T, typename Context>
+void MatmulDoubleGradKernel(const Context& dev_ctx,
+                            const DenseTensor& x,
+                            const DenseTensor& y,
+                            const DenseTensor& dout,
+                            paddle::optional<const DenseTensor&> ddx,
+                            paddle::optional<const DenseTensor&> ddy,
+                            bool transpose_x,
+                            bool transpose_y,
+                            DenseTensor* dx,
+                            DenseTensor* dy,
+                            DenseTensor* ddout);
+
+template <typename T, typename Context>
+void MatmulTripleGradKernel(const Context& dev_ctx,
+                            const DenseTensor& x,
+                            const DenseTensor& y,
+                            const DenseTensor& dout,
+                            const DenseTensor& ddx,
+                            const DenseTensor& ddy,
+                            paddle::optional<const DenseTensor&> d_dx,
+                            paddle::optional<const DenseTensor&> d_dy,
+                            paddle::optional<const DenseTensor&> d_ddout,
+                            bool transpose_x,
+                            bool transpose_y,
+                            DenseTensor* out_d_x,
+                            DenseTensor* out_d_y,
+                            DenseTensor* out_d_dout,
+                            DenseTensor* out_d_ddx,
+                            DenseTensor* out_d_ddy);
+
+}  // namespace pten
diff --git a/paddle/pten/kernels/matmul_kernel.h b/paddle/pten/kernels/matmul_kernel.h
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..f9cb2c3801caa863b135ec2d4d188efff6dceb81 100644
--- a/paddle/pten/kernels/matmul_kernel.h
+++ b/paddle/pten/kernels/matmul_kernel.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/infermeta/binary.h"
+
+#include "paddle/pten/kernels/empty_kernel.h"
+
+namespace pten {
+
+template <typename T, typename Context>
+void MatmulKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  bool transpose_x,
+                  bool transpose_y,
+                  DenseTensor* out);
+
+template <typename T, typename Context>
+DenseTensor Matmul(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& y,
+                   bool transpose_x,
+                   bool transpose_y) {
+  auto out_meta = MatmulInferMeta(x.meta(), y.meta(), transpose_x, transpose_y);
+  auto dense_out = Empty<T, Context>(dev_ctx, std::move(out_meta));
+  MatmulKernel<T, Context>(dev_ctx, x, y, transpose_x, transpose_y, &dense_out);
+  return dense_out;
+}
+
+}  // namespace pten
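
Usage note (not part of the patch): a minimal sketch of how the Matmul entry point declared in paddle/pten/kernels/matmul_kernel.h is meant to be consumed. The device context type and the helper name below are illustrative assumptions; only pten::Matmul itself comes from this diff.

    // Illustrative only. Assumes dev_ctx is a device context accepted by the
    // pten kernels (e.g. a CPU context) and that x / y are already-initialized
    // DenseTensors; tensor construction is outside the scope of this sketch.
    #include "paddle/pten/kernels/matmul_kernel.h"

    namespace example {

    // x: [batch, M, K], y: [K, N] -> out: [batch, M, N]; the output shape is
    // derived by MatmulInferMeta and the buffer allocated by Empty inside
    // pten::Matmul, so no manual Resize/mutable_data is needed here.
    template <typename Context>
    pten::DenseTensor BatchedProjection(const Context& dev_ctx,
                                        const pten::DenseTensor& x,
                                        const pten::DenseTensor& y) {
      return pten::Matmul<float>(
          dev_ctx, x, y, /*transpose_x=*/false, /*transpose_y=*/false);
    }

    }  // namespace example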