diff --git a/graph/aligned_ptr.cc b/graph/aligned_ptr.cc index d3c58a21b6a2e39f9491235263c7b30df6e95526..60c9a3a3e3e47057c4a37eee437875f725d814fb 100644 --- a/graph/aligned_ptr.cc +++ b/graph/aligned_ptr.cc @@ -88,11 +88,11 @@ std::shared_ptr AlignedPtr::BuildFromAllocFunc(const AlignedPtr::All return aligned_ptr; } -std::shared_ptr AlignedPtr::BuildFromData(uint8_t *data, const AlignedPtr::Deleter &delete_func) { - if (data == nullptr || delete_func == nullptr) { - REPORT_INNER_ERROR("E19999", "data is nullptr or delete_func is nullptr"); - GELOGE(FAILED, "[Check][Param] data/delete_func is null"); - return nullptr; +std::shared_ptr AlignedPtr::BuildFromAllocFunc(const AlignedPtr::Allocator &alloc_func) { + if ((alloc_func == nullptr)) { + REPORT_INNER_ERROR("E19999", "alloc_func is nullptr, check invalid"); + GELOGE(FAILED, "[Check][Param] alloc_func is null"); + return nullptr; } auto aligned_ptr = MakeShared(); if (aligned_ptr == nullptr) { @@ -100,8 +100,13 @@ std::shared_ptr AlignedPtr::BuildFromData(uint8_t *data, const Align GELOGE(INTERNAL_ERROR, "[Create][AlignedPtr] make shared for AlignedPtr failed"); return nullptr; } - aligned_ptr->base_.reset(data); - aligned_ptr->base_.get_deleter() = delete_func; + aligned_ptr->base_.reset(); + alloc_func(aligned_ptr->base_); + if (aligned_ptr->base_ == nullptr) { + REPORT_CALL_ERROR("E19999", "allocate for AlignedPtr failed"); + GELOGE(FAILED, "[Call][AllocFunc] allocate for AlignedPtr failed"); + return nullptr; + } aligned_ptr->aligned_addr_ = aligned_ptr->base_.get(); return aligned_ptr; } diff --git a/graph/ge_tensor.cc b/graph/ge_tensor.cc index dbaf837ae0eea8b76fa72cd87d7416e1d3fd2d0b..025f5a4866a78da9ba90ef6b82b7da1be0078e70 100644 --- a/graph/ge_tensor.cc +++ b/graph/ge_tensor.cc @@ -733,7 +733,10 @@ graphStatus TensorData::SetData(uint8_t *data, size_t size, const AlignedPtr::De return GRAPH_FAILED; } length_ = size; - aligned_ptr_ = AlignedPtr::BuildFromData(data, delete_fuc); + aligned_ptr_ = AlignedPtr::BuildFromAllocFunc([&](std::unique_ptr &ptr) { + ptr.reset(data); + ptr.get_deleter() = delete_fuc; + }); return GRAPH_SUCCESS; } @@ -870,13 +873,10 @@ void GeTensor::BuildAlignerPtrWithProtoData() { tensor_data_.length_ = proto_msg->data().size(); tensor_data_.aligned_ptr_.reset(); tensor_data_.aligned_ptr_ = - AlignedPtr::BuildFromAllocFunc([&proto_msg](std::unique_ptr &ptr) { - ptr.reset(const_cast( - reinterpret_cast(proto_msg->data().data()))); - }, - [](uint8_t *ptr) { - ptr = nullptr; - }); + AlignedPtr::BuildFromAllocFunc([&proto_msg](std::unique_ptr &ptr) { + ptr.reset(const_cast(reinterpret_cast(proto_msg->data().data()))); + ptr.get_deleter() = [](uint8_t *) {}; + }); } GeTensorDesc GeTensor::GetTensorDesc() const { return DescReference(); } diff --git a/inc/graph/aligned_ptr.h b/inc/graph/aligned_ptr.h index 3924f3d85304b15047a16e7228d5e828c2baef57..2146266e4cd0c4cb79d472e31d6ee9219f318c0a 100644 --- a/inc/graph/aligned_ptr.h +++ b/inc/graph/aligned_ptr.h @@ -39,8 +39,7 @@ class AlignedPtr { static std::shared_ptr BuildFromAllocFunc(const AlignedPtr::Allocator &alloc_func, const AlignedPtr::Deleter &delete_func); - static std::shared_ptr BuildFromData(uint8_t *data, - const AlignedPtr::Deleter &delete_func); /*lint !e148*/ + static std::shared_ptr BuildFromAllocFunc(const AlignedPtr::Allocator &alloc_func); /*lint !e148*/ private: std::unique_ptr base_ = nullptr; uint8_t *aligned_addr_ = nullptr; diff --git a/tests/ut/graph/testcase/aligned_ptr_unittest.cc b/tests/ut/graph/testcase/aligned_ptr_unittest.cc index f9327a1975319fd48d87f5a45ccb8d6712eac05f..f13e8082f0cc5afdc90ad613594567deea18d26f 100644 --- a/tests/ut/graph/testcase/aligned_ptr_unittest.cc +++ b/tests/ut/graph/testcase/aligned_ptr_unittest.cc @@ -44,14 +44,24 @@ namespace ge ASSERT_EQ(output_base, 0); } - TEST_F(UtestAlignedPtr, BuildFromData_success) { - auto deleter = [](uint8_t *ptr) { - delete ptr; - ptr = nullptr; + TEST_F(UtestAlignedPtr, BuildFromAllocFunc_success) { + uint8_t *data = new uint8_t(10); + auto allocate = [&data](std::unique_ptr &ptr) { + ptr.reset(data); + ptr.get_deleter() = [](uint8_t *addr) { + delete addr; + addr = nullptr; + }; }; - uint8_t* data_ptr = new uint8_t(10); - auto aligned_ptr = AlignedPtr::BuildFromData(data_ptr, deleter); + auto aligned_ptr = AlignedPtr::BuildFromAllocFunc(allocate); uint8_t result = *(aligned_ptr->Get()); ASSERT_EQ(result, 10); } + + TEST_F(UtestAlignedPtr, BuildFromAllocFunc_failed) { + auto allocate = nullptr; + auto aligned_ptr = AlignedPtr::BuildFromAllocFunc(allocate); + ASSERT_EQ(aligned_ptr, nullptr); + } + }