From 24e02f6afbf60c409594d8cbc8c92fbe227c9fea Mon Sep 17 00:00:00 2001
From: Paddle CI_MAC
Date: Thu, 2 Sep 2021 14:23:02 +0800
Subject: [PATCH] migrate_35177

---
 cmake/external/xpu.cmake                      | 128 ++++++++++--------
 paddle/fluid/framework/operator.cc            |  32 +++--
 paddle/fluid/imperative/prepared_operator.cc  |  69 +++++++---
 paddle/fluid/operators/label_smooth_op_xpu.cc |  57 ++++++++
 .../unittests/xpu/test_label_smooth_op_xpu.py |  64 +++++++++
 5 files changed, 269 insertions(+), 81 deletions(-)
 create mode 100644 paddle/fluid/operators/label_smooth_op_xpu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_label_smooth_op_xpu.py

diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake
index f846623602..2fc5bf7954 100644
--- a/cmake/external/xpu.cmake
+++ b/cmake/external/xpu.cmake
@@ -7,52 +7,74 @@ SET(XPU_PROJECT "extern_xpu")
 SET(XPU_API_LIB_NAME "libxpuapi.so")
 SET(XPU_RT_LIB_NAME "libxpurt.so")
 
-if(NOT XPU_SDK_ROOT)
-  if (WITH_AARCH64)
-    SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/aarch64/xpu_2021_01_13.tar.gz" CACHE STRING "" FORCE)
-  elseif(WITH_SUNWAY)
-    SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/sunway/xpu_2021_01_13.tar.gz" CACHE STRING "" FORCE)
-  else()
-    SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_04_09.tar.gz" CACHE STRING "" FORCE)
-  endif()
-
-  SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu")
-  SET(XPU_DOWNLOAD_DIR "${XPU_SOURCE_DIR}/src/${XPU_PROJECT}")
-  SET(XPU_INSTALL_DIR "${THIRD_PARTY_PATH}/install/xpu")
-  SET(XPU_API_INC_DIR "${THIRD_PARTY_PATH}/install/xpu/include")
-  SET(XPU_LIB_DIR "${THIRD_PARTY_PATH}/install/xpu/lib")
-
-  SET(XPU_API_LIB "${XPU_LIB_DIR}/${XPU_API_LIB_NAME}")
-  SET(XPU_RT_LIB "${XPU_LIB_DIR}/${XPU_RT_LIB_NAME}")
-
-  SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${XPU_INSTALL_DIR}/lib")
-
-  FILE(WRITE ${XPU_DOWNLOAD_DIR}/CMakeLists.txt
-    "PROJECT(XPU)\n"
-    "cmake_minimum_required(VERSION 3.0)\n"
-    "install(DIRECTORY xpu/include xpu/lib \n"
-    "        DESTINATION ${XPU_INSTALL_DIR})\n")
-
-  ExternalProject_Add(
-      ${XPU_PROJECT}
-      ${EXTERNAL_PROJECT_LOG_ARGS}
-      PREFIX                ${XPU_SOURCE_DIR}
-      DOWNLOAD_DIR          ${XPU_DOWNLOAD_DIR}
-      DOWNLOAD_COMMAND      wget --no-check-certificate ${XPU_URL} -c -q -O xpu.tar.gz
-                            && tar xvf xpu.tar.gz
-      DOWNLOAD_NO_PROGRESS  1
-      UPDATE_COMMAND        ""
-      CMAKE_ARGS            -DCMAKE_INSTALL_PREFIX=${XPU_INSTALL_ROOT}
-      CMAKE_CACHE_ARGS      -DCMAKE_INSTALL_PREFIX:PATH=${XPU_INSTALL_ROOT}
-  )
-else()
-  SET(XPU_API_INC_DIR "${XPU_SDK_ROOT}/XTDK/include/")
-  SET(XPU_API_LIB "${XPU_SDK_ROOT}/XTDK/shlib/libxpuapi.so")
-  SET(XPU_RT_LIB "${XPU_SDK_ROOT}/XTDK/runtime/shlib/libxpurt.so")
-  SET(XPU_LIB_DIR "${XPU_SDK_ROOT}/XTDK/shlib/")
-endif()
+IF(WITH_AARCH64)
+  SET(XPU_XRE_DIR_NAME "xre-kylin_aarch64")
+  SET(XPU_XDNN_DIR_NAME "xdnn-kylin_aarch64")
+  SET(XPU_XCCL_DIR_NAME "xccl-kylin_aarch64")
+ELSEIF(WITH_SUNWAY)
+  SET(XPU_XRE_DIR_NAME "xre-deepin_sw6_64")
+  SET(XPU_XDNN_DIR_NAME "xdnn-deepin_sw6_64")
+  SET(XPU_XCCL_DIR_NAME "xccl-deepin_sw6_64")
+ELSEIF(WITH_BDCENTOS)
+  SET(XPU_XRE_DIR_NAME "xre-bdcentos_x86_64")
+  SET(XPU_XDNN_DIR_NAME "xdnn-bdcentos_x86_64")
+  SET(XPU_XCCL_DIR_NAME "xccl-bdcentos_x86_64")
+ELSEIF(WITH_UBUNTU)
+  SET(XPU_XRE_DIR_NAME "xre-ubuntu_x86_64")
+  SET(XPU_XDNN_DIR_NAME "xdnn-ubuntu_x86_64")
+  SET(XPU_XCCL_DIR_NAME "xccl-bdcentos_x86_64")
+ELSEIF(WITH_CENTOS)
+  SET(XPU_XRE_DIR_NAME "xre-centos7_x86_64")
+  SET(XPU_XDNN_DIR_NAME "xdnn-centos7_x86_64")
+  SET(XPU_XCCL_DIR_NAME "xccl-bdcentos_x86_64")
+
+ELSE ()
+  SET(XPU_XRE_DIR_NAME "xre-ubuntu_x86_64")
"xre-ubuntu_x86_64") + SET(XPU_XDNN_DIR_NAME "xdnn-ubuntu_x86_64") + SET(XPU_XCCL_DIR_NAME "xccl-bdcentos_x86_64") +ENDIF() + +SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") +SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210826") +SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) +SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) +SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) +SET(XPU_PACK_DEPENCE_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/pack_paddle_depence.sh" CACHE STRING "" FORCE) + +SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu") +SET(XPU_DOWNLOAD_DIR "${XPU_SOURCE_DIR}/src/${XPU_PROJECT}") +SET(XPU_INSTALL_DIR "${THIRD_PARTY_PATH}/install/xpu") +SET(XPU_INC_DIR "${THIRD_PARTY_PATH}/install/xpu/include") +SET(XPU_LIB_DIR "${THIRD_PARTY_PATH}/install/xpu/lib") + +SET(XPU_API_LIB "${XPU_LIB_DIR}/${XPU_API_LIB_NAME}") +SET(XPU_RT_LIB "${XPU_LIB_DIR}/${XPU_RT_LIB_NAME}") + +SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${XPU_INSTALL_DIR}/lib") -INCLUDE_DIRECTORIES(${XPU_API_INC_DIR}) +FILE(WRITE ${XPU_DOWNLOAD_DIR}/CMakeLists.txt + "PROJECT(XPU)\n" + "cmake_minimum_required(VERSION 3.0)\n" + "install(DIRECTORY xpu/include xpu/lib \n" + " DESTINATION ${XPU_INSTALL_DIR})\n") + +ExternalProject_Add( + ${XPU_PROJECT} + ${EXTERNAL_PROJECT_LOG_ARGS} + PREFIX ${XPU_SOURCE_DIR} + DOWNLOAD_DIR ${XPU_DOWNLOAD_DIR} + DOWNLOAD_COMMAND wget ${XPU_PACK_DEPENCE_URL} + && bash pack_paddle_depence.sh ${XPU_XRE_URL} ${XPU_XRE_DIR_NAME} ${XPU_XDNN_URL} ${XPU_XDNN_DIR_NAME} ${XPU_XCCL_URL} ${XPU_XCCL_DIR_NAME} + + DOWNLOAD_NO_PROGRESS 1 + UPDATE_COMMAND "" + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${XPU_INSTALL_ROOT} + CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${XPU_INSTALL_ROOT} + BUILD_BYPRODUCTS ${XPU_API_LIB} + BUILD_BYPRODUCTS ${XPU_RT_LIB} +) + +INCLUDE_DIRECTORIES(${XPU_INC_DIR}) ADD_LIBRARY(shared_xpuapi SHARED IMPORTED GLOBAL) set_property(TARGET shared_xpuapi PROPERTY IMPORTED_LOCATION "${XPU_API_LIB}") @@ -62,7 +84,7 @@ generate_dummy_static_lib(LIB_NAME "xpulib" GENERATOR "xpu.cmake") TARGET_LINK_LIBRARIES(xpulib ${XPU_API_LIB} ${XPU_RT_LIB}) -if (WITH_XPU_BKCL) +IF(WITH_XPU_BKCL) MESSAGE(STATUS "Compile with XPU BKCL!") ADD_DEFINITIONS(-DPADDLE_WITH_XPU_BKCL) @@ -71,15 +93,11 @@ if (WITH_XPU_BKCL) SET(XPU_BKCL_INC_DIR "${THIRD_PARTY_PATH}/install/xpu/include") INCLUDE_DIRECTORIES(${XPU_BKCL_INC_DIR}) TARGET_LINK_LIBRARIES(xpulib ${XPU_API_LIB} ${XPU_RT_LIB} ${XPU_BKCL_LIB}) -else(WITH_XPU_BKCL) - TARGET_LINK_LIBRARIES(xpulib ${XPU_API_LIB} ${XPU_RT_LIB} ) -endif(WITH_XPU_BKCL) - -if(NOT XPU_SDK_ROOT) - ADD_DEPENDENCIES(xpulib ${XPU_PROJECT}) -else() - ADD_CUSTOM_TARGET(extern_xpu DEPENDS xpulib) -endif() +ELSE(WITH_XPU_BKCL) + TARGET_LINK_LIBRARIES(xpulib ${XPU_API_LIB} ${XPU_RT_LIB}) +ENDIF(WITH_XPU_BKCL) + +ADD_DEPENDENCIES(xpulib ${XPU_PROJECT}) # Ensure that xpu/api.h can be included without dependency errors. 
 file(GENERATE OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/.xpu_headers_dummy.cc CONTENT "")
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 955c917b2c..57f9d094ac 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -36,7 +36,8 @@ class LoDTensor;
 }  // namespace framework
 }  // namespace paddle
 #ifdef PADDLE_WITH_XPU
-#include "paddle/fluid/platform/xpu_info.h"
+#include "paddle/fluid/platform/xpu/xpu_info.h"
+#include "paddle/fluid/platform/xpu/xpu_op_list.h"
 #endif
 
 #ifdef PADDLE_WITH_MKLDNN
@@ -1228,6 +1229,8 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
     // will be executed and a warning will be given at the same time.
     if (SupportGPU()) {
       expected_kernel_key.place_ = dev_ctx->GetPlace();
+    } else if (SupportNPU()) {
+      expected_kernel_key.place_ = dev_ctx->GetPlace();
     } else {
       expected_kernel_key.place_ = platform::CPUPlace();
       LOG_FIRST_N(WARNING, 1)
@@ -1251,8 +1254,10 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
   }
 #endif
 #ifdef PADDLE_WITH_XPU
-  if (kernel_iter == kernels.end() &&
-      is_xpu_place(expected_kernel_key.place_)) {
+  if (is_xpu_place(expected_kernel_key.place_) &&
+      (kernel_iter == kernels.end() ||
+       !paddle::platform::is_xpu_support_op(type_, expected_kernel_key) ||
+       paddle::platform::is_in_xpu_black_list(type_))) {
     VLOG(3) << "missing XPU kernel: " << type_
             << ", expected_kernel_key:" << expected_kernel_key
            << ", falling back to the CPU one!";
@@ -1299,7 +1304,11 @@ void OperatorWithKernel::TransferInplaceVarsBack(
     auto* transformed_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var);
     auto original_dims = original_tensor->dims();
     original_tensor->ShareDataWith(*transformed_tensor);
-    original_tensor->Resize(original_dims);
+    // Skip the resize for reshape2/reshape2_grad: otherwise the output dims
+    // of the NPU reshape operator would be overwritten when it runs inplace.
+    if (type_ != "reshape2" && type_ != "reshape2_grad") {
+      original_tensor->Resize(original_dims);
+    }
   }
 }
 
@@ -1525,7 +1534,12 @@ Scope* OperatorWithKernel::PrepareData(
   // the rest iterations to save the elapsed time.
   // We do not support skipping PrepareData in while block, because the Op's
   // input may be changed by subsequent Ops, which may cause an error.
-  if (pre_scope_ == &scope && new_scope == nullptr) {
+
+  // For inference, ops behind a conditional branch aren't supported well,
+  // so conservatively disable the prepare-data optimization here.
+  bool force_prepare_data = HasAttr("inference_force_prepare_data") &&
+                            Attr<bool>("inference_force_prepare_data");
+  if (pre_scope_ == &scope && new_scope == nullptr && !force_prepare_data) {
     need_prepare_data_ = false;
   }
 
@@ -1549,10 +1563,10 @@ void OperatorWithKernel::ParseInputDataType(
       } else if (var->IsType<SelectedRows>()) {
         t = &(var->Get<SelectedRows>().value());
       } else if (var->IsType<LoDTensorArray>()) {
-        auto t_arr = var->Get<LoDTensorArray>();
-        for (size_t j = 0; j < t_arr.size(); j++) {
-          if (t_arr[j].IsInitialized()) {
-            t = &(t_arr[j]);
+        auto t_arr = &var->Get<LoDTensorArray>();
+        for (size_t j = 0; j < t_arr->size(); j++) {
+          if (t_arr->at(j).IsInitialized()) {
+            t = &(t_arr->at(j));
           }
         }
       }
diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc
index 2a3b6424d4..8f45cd0fa6 100644
--- a/paddle/fluid/imperative/prepared_operator.cc
+++ b/paddle/fluid/imperative/prepared_operator.cc
@@ -15,7 +15,12 @@
 #include "paddle/fluid/imperative/prepared_operator.h"
 
 #include "paddle/fluid/framework/data_type_transform.h"
+#include "paddle/fluid/framework/details/nan_inf_utils.h"
 #include "paddle/fluid/imperative/infer_shape_context.h"
+#ifdef PADDLE_WITH_XPU
+#include "paddle/fluid/platform/xpu/xpu_op_list.h"
+#endif
+DECLARE_bool(check_nan_inf);
 
 namespace paddle {
 namespace imperative {
@@ -88,7 +93,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
                        const NameVarMap<VarType>& outs,
                        const framework::OperatorWithKernel& op,
                        const platform::Place& place,
-                       const framework::AttributeMap& attrs) {
+                       const framework::AttributeMap& attrs,
+                       const framework::AttributeMap& default_attrs) {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(place);
 
@@ -100,14 +106,17 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
   // Const qualifier of Attrs had to be discarded to overwrite it.
   if (FLAGS_use_mkldnn) {
     auto& mutable_op_attrs = const_cast<framework::AttributeMap&>(op.Attrs());
-    mutable_op_attrs = attrs;
+    mutable_op_attrs = default_attrs;
+    for (auto& attr : attrs) {
+      mutable_op_attrs[attr.first] = attr.second;
+    }
   }
 #endif
 
   // 1. get expected kernel key
-  auto expected_kernel_key =
-      op.GetExpectedKernelType(DygraphExecutionContext<VarType>(
-          op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs));
+  auto expected_kernel_key = op.GetExpectedKernelType(
+      DygraphExecutionContext<VarType>(op, framework::Scope(), *dev_ctx, ctx,
+                                       ins, outs, attrs, default_attrs));
   VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
 
   // 2. check if op[type] has kernel registered.
@@ -122,8 +131,23 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
   auto& kernels = kernels_iter->second;
   auto kernel_iter = kernels.find(expected_kernel_key);
 #ifdef PADDLE_WITH_XPU
+  if (is_xpu_place(expected_kernel_key.place_) &&
+      (kernel_iter == kernels.end() ||
+       !paddle::platform::is_xpu_support_op(op.Type(), expected_kernel_key) ||
+       paddle::platform::is_in_xpu_black_list(op.Type()))) {
+    VLOG(3) << "missing XPU kernel: " << op.Type()
+            << ", expected_kernel_key:" << expected_kernel_key
+            << ", falling back to the CPU one!";
+    expected_kernel_key.place_ = platform::CPUPlace();
+    kernel_iter = kernels.find(expected_kernel_key);
+  }
+#endif
+#ifdef PADDLE_WITH_ASCEND_CL
   if (kernel_iter == kernels.end() &&
-      is_xpu_place(expected_kernel_key.place_)) {
+      is_npu_place(expected_kernel_key.place_)) {
+    VLOG(3) << "missing NPU kernel: " << op.Type()
+            << ", expected_kernel_key:" << expected_kernel_key
+            << ", falling back to the CPU one!";
     expected_kernel_key.place_ = platform::CPUPlace();
     kernel_iter = kernels.find(expected_kernel_key);
   }
@@ -145,16 +169,19 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
                                const NameVarMap<VarBase>& outs,
                                const framework::OperatorWithKernel& op,
                                const platform::Place& place,
-                               const framework::AttributeMap& attrs) {
-  return PrepareImpl<VarBase>(ins, outs, op, place, attrs);
+                               const framework::AttributeMap& attrs,
+                               const framework::AttributeMap& default_attrs) {
+  return PrepareImpl<VarBase>(ins, outs, op, place, attrs, default_attrs);
 }
 
 PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
                                const NameVarMap<VariableWrapper>& outs,
                                const framework::OperatorWithKernel& op,
                                const platform::Place& place,
-                               const framework::AttributeMap& attrs) {
-  return PrepareImpl<VariableWrapper>(ins, outs, op, place, attrs);
+                               const framework::AttributeMap& attrs,
+                               const framework::AttributeMap& default_attrs) {
+  return PrepareImpl<VariableWrapper>(ins, outs, op, place, attrs,
+                                      default_attrs);
 }
 
 template <typename VarType>
@@ -163,17 +190,23 @@ static void PreparedOpRunImpl(
     const framework::OpKernelType& kernel_type,
     const framework::OperatorWithKernel::OpKernelFunc& func,
     platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
-    const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs) {
+    const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
+    const framework::AttributeMap& default_attrs) {
   // TODO(zjl): remove scope in dygraph
   framework::Scope scope;
 
   DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs,
-                                                    op.Type());
+                                                    &default_attrs, op.Type());
   static_cast<const framework::OperatorWithKernel&>(op).InferShape(
       &infer_shape_ctx);
 
   func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
-                                        attrs));
+                                        attrs, default_attrs));
+
+  if (FLAGS_check_nan_inf) {
+    framework::details::CheckOpHasNanOrInfInDygraph<VarType>(
+        op.Type(), outs, dev_ctx->GetPlace());
+  }
 
   /**
    * [ Why need handle complex gradient to real gradient? ]
@@ -194,16 +227,18 @@ static void PreparedOpRunImpl(
 
 void PreparedOp::Run(const NameVarMap<VarBase>& ins,
                      const NameVarMap<VarBase>& outs,
-                     const framework::AttributeMap& attrs) {
-  PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, dev_ctx_, ins,
-                             outs, attrs);
+                     const framework::AttributeMap& attrs,
+                     const framework::AttributeMap& default_attrs) {
+  PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, dev_ctx_, ins,
+                             outs, attrs, default_attrs);
 }
 
 void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
                      const NameVarMap<VariableWrapper>& outs,
-                     const framework::AttributeMap& attrs) {
-  PreparedOpRunImpl<VariableWrapper>(op_, ctx_, kernel_type_, func_, dev_ctx_,
-                                     ins, outs, attrs);
+                     const framework::AttributeMap& attrs,
+                     const framework::AttributeMap& default_attrs) {
+  PreparedOpRunImpl<VariableWrapper>(op_, ctx_, kernel_type_, func_, dev_ctx_,
+                                     ins, outs, attrs, default_attrs);
 }
 
 }  // namespace imperative
diff --git a/paddle/fluid/operators/label_smooth_op_xpu.cc b/paddle/fluid/operators/label_smooth_op_xpu.cc
new file mode 100644
index 0000000000..6b63507539
--- /dev/null
+++ b/paddle/fluid/operators/label_smooth_op_xpu.cc
@@ -0,0 +1,57 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef PADDLE_WITH_XPU
+#include "paddle/fluid/operators/label_smooth_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class LabelSmoothXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    auto* out_t = ctx.Output<framework::LoDTensor>("Out");
+    auto* in_t = ctx.Input<framework::LoDTensor>("X");
+    auto* dist_t = ctx.Input<framework::Tensor>("PriorDist");
+    auto label_dim = in_t->dims()[in_t->dims().size() - 1];
+    auto ptr = out_t->mutable_data<T>(ctx.GetPlace());
+
+    auto epsilon = ctx.Attr<float>("epsilon");
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    if (dist_t) {
+      PADDLE_THROW(
+          platform::errors::External("XPU doesn't support dist label smooth"));
+    } else {
+      int r = xpu::label_smooth<T>(dev_ctx.x_context(), in_t->data<T>(), ptr,
+                                   in_t->numel(), epsilon, label_dim);
+      PADDLE_ENFORCE_EQ(
+          r, XPU_SUCCESS,
+          platform::errors::External("XPU API(label_smooth) return wrong "
+                                     "value[%d %s]",
+                                     r, XPUAPIErrorMsg[r]));
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_XPU_KERNEL(
+    label_smooth,
+    ops::LabelSmoothXPUKernel<paddle::platform::XPUDeviceContext, float>);
+#endif
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_label_smooth_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_label_smooth_op_xpu.py
new file mode 100644
index 0000000000..5a827c1beb
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_label_smooth_op_xpu.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import paddle
+import numpy as np
+import sys
+sys.path.append("..")
+from op_test_xpu import XPUOpTest
+
+paddle.enable_static()
+
+
+class TestLabelSmoothOp(XPUOpTest):
+    def config(self):
+        self.op_type = "label_smooth"
+        self.epsilon = 0.1
+        self.use_xpu = True
+        batch_size, self.label_dim = 10, 12
+        self.label = np.zeros((batch_size, self.label_dim)).astype("float32")
+        nonzero_index = np.random.randint(self.label_dim, size=(batch_size))
+        self.label[np.arange(batch_size), nonzero_index] = 1
+
+    def setUp(self):
+        self.config()
+        smoothed_label = (1 - self.epsilon
+                          ) * self.label + self.epsilon / self.label_dim
+        self.inputs = {'X': self.label}
+        self.attrs = {'epsilon': self.epsilon}
+        self.outputs = {'Out': smoothed_label}
+
+    def test_check_output(self):
+        if not paddle.is_compiled_with_xpu():
+            return
+        self.check_output_with_place(paddle.XPUPlace(0), atol=1e-6)
+
+    def test_check_grad(self):
+        return
+
+
+class TestLabelSmoothOp3D(TestLabelSmoothOp):
+    def setUp(self):
+        super(TestLabelSmoothOp3D, self).setUp()
+        self.inputs['X'] = self.inputs['X'].reshape(
+            [2, -1, self.inputs['X'].shape[-1]])
+        self.outputs['Out'] = self.outputs['Out'].reshape(self.inputs['X']
+                                                          .shape)
+
+
+if __name__ == '__main__':
+    unittest.main()
--
Gitee
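
Note: the transform exercised by the new XPU kernel and verified by the test
above is out = (1 - epsilon) * label + epsilon / label_dim, applied along the
last axis of X. The following is a minimal NumPy sketch of the expected-output
computation from the test's setUp() (epsilon and shapes are taken from the
test; the variable name "hot" is illustrative only):

    import numpy as np

    # Values from TestLabelSmoothOp.config().
    epsilon = 0.1
    batch_size, label_dim = 10, 12

    # One-hot labels: one random hot entry per row.
    label = np.zeros((batch_size, label_dim), dtype="float32")
    hot = np.random.randint(label_dim, size=batch_size)
    label[np.arange(batch_size), hot] = 1

    # Label smoothing: scale the hot entry by (1 - epsilon) and spread
    # epsilon uniformly over all classes.
    smoothed = (1 - epsilon) * label + epsilon / label_dim

    # Probability mass is preserved per row:
    # (1 - eps) + label_dim * (eps / label_dim) = 1.
    assert np.allclose(smoothed.sum(axis=1), 1.0, atol=1e-6)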