From 504d1f17fa4ef187d942707cb2c4e3647cf27c0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=B6=9B?= Date: Mon, 29 Sep 2025 06:59:16 +0000 Subject: [PATCH] =?UTF-8?q?=E5=9B=9E=E9=80=80=20'Pull=20Request=20!6955=20?= =?UTF-8?q?:=20fix=20repeat=20infomation'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- base/CMakeLists.txt | 2 +- base/context_builder/context_holder.cc | 6 +- base/context_builder/context_holder_builder.h | 2 +- .../op_context_builder_base.cc | 244 --------- .../op_context_builder_impl.cc | 81 +-- .../context_builder/op_context_builder_impl.h | 20 +- .../op_infer_datatype_context_builder.cc | 88 +++- .../op_infer_shape_context_builder.cc | 68 ++- .../op_infer_shape_range_context_builder.cc | 78 +-- base/context_builder/op_info.cc | 106 ++++ base/context_builder/op_info.h | 43 -- base/context_builder/op_info_impl.h | 131 +++++ .../op_kernel_run_context_builder.cc | 46 +- .../op_tiling_context_builder.cc | 74 +-- .../op_tiling_parse_context_builder.cc | 57 +-- inc/base/attr/attrs_to_buffer.h | 6 +- .../base/context_builder/context_holder.h | 8 +- .../context_builder/op_context_builder_base.h | 104 ---- .../op_infer_datatype_context_builder.h | 69 ++- .../op_infer_shape_context_builder.h | 53 +- .../op_infer_shape_range_context_builder.h | 54 +- inc/external/base/context_builder/op_info.h | 101 ++++ .../op_kernel_run_context_builder.h | 54 +- .../op_tiling_context_builder.h | 69 ++- .../op_tiling_parse_context_builder.h | 67 +-- .../base/testcase/context_builder_unittest.cc | 477 ++++++------------ 26 files changed, 896 insertions(+), 1212 deletions(-) delete mode 100644 base/context_builder/op_context_builder_base.cc create mode 100644 base/context_builder/op_info.cc delete mode 100644 base/context_builder/op_info.h create mode 100644 base/context_builder/op_info_impl.h delete mode 100644 inc/external/base/context_builder/op_context_builder_base.h create mode 100644 inc/external/base/context_builder/op_info.h diff --git a/base/CMakeLists.txt b/base/CMakeLists.txt index cae15d0d4b..e3320ae7e8 100644 --- a/base/CMakeLists.txt +++ b/base/CMakeLists.txt @@ -145,13 +145,13 @@ set(SRC_LIST "runtime/tiling_data.cc" context_builder/op_context_builder_impl.cc context_builder/context_holder.cc + context_builder/op_info.cc context_builder/op_kernel_run_context_builder.cc context_builder/op_tiling_context_builder.cc context_builder/op_tiling_parse_context_builder.cc context_builder/op_infer_datatype_context_builder.cc context_builder/op_infer_shape_context_builder.cc context_builder/op_infer_shape_range_context_builder.cc - context_builder/op_context_builder_base.cc ) ############ libmetadef.so ############ diff --git a/base/context_builder/context_holder.cc b/base/context_builder/context_holder.cc index 74eec0c193..f71f4803c0 100644 --- a/base/context_builder/context_holder.cc +++ b/base/context_builder/context_holder.cc @@ -11,17 +11,17 @@ #include "common/checker.h" namespace gert { -void *ContextHolderVoid::GetContext() const { +void *ContextHolderVoid::GetContext() { GE_ASSERT_NOTNULL(ctx_holder_impl_, "ctx_holder_impl_ is null"); return ctx_holder_impl_->GetContext(); } ContextHolderVoid::ContextHolderVoid() = default; ContextHolderVoid::~ContextHolderVoid() = default; -ContextHolderVoid::ContextHolderVoid(ContextHolderVoid &&other) noexcept { +ContextHolderVoid::ContextHolderVoid(ContextHolderVoid &&other) { ctx_holder_impl_ = std::move(other.ctx_holder_impl_); } -ContextHolderVoid &ContextHolderVoid::operator=(ContextHolderVoid &&other) noexcept { +ContextHolderVoid &ContextHolderVoid::operator=(ContextHolderVoid &&other) { if (this != &other) { ctx_holder_impl_ = std::move(other.ctx_holder_impl_); } diff --git a/base/context_builder/context_holder_builder.h b/base/context_builder/context_holder_builder.h index 7aabbf6492..78ee901b6b 100644 --- a/base/context_builder/context_holder_builder.h +++ b/base/context_builder/context_holder_builder.h @@ -14,7 +14,7 @@ namespace gert { class ContextHolderBuilder { public: static ContextHolderVoid Create(std::unique_ptr &&ctx_holder_impl) { - GE_ASSERT_NOTNULL(ctx_holder_impl, "ctx_holder_impl is null while creating ContextHolder"); + GE_ASSERT_NOTNULL(ctx_holder_impl, "ctx_holder_impl is null whle creating ContextHolder"); ContextHolderVoid holder; holder.ctx_holder_impl_ = std::move(ctx_holder_impl); return holder; diff --git a/base/context_builder/op_context_builder_base.cc b/base/context_builder/op_context_builder_base.cc deleted file mode 100644 index 74a12169d0..0000000000 --- a/base/context_builder/op_context_builder_base.cc +++ /dev/null @@ -1,244 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "exe_graph/runtime/tensor.h" -#include "common/ge_common/util.h" -#include "graph/debug/ge_util.h" -#include "base/context_builder/op_context_builder_base.h" -#include "base/context_builder/op_tiling_context_builder.h" -#include "base/context_builder/op_context_builder_impl.h" -#include "base/context_builder/op_infer_shape_context_builder.h" -#include "base/context_builder/op_infer_datatype_context_builder.h" -#include "base/context_builder/op_tiling_parse_context_builder.h" -#include "base/context_builder/op_infer_shape_range_context_builder.h" -#include "base/context_builder/op_kernel_run_context_builder.h" - -namespace gert { -template -OpContextBuilderBase::OpContextBuilderBase() : impl_(ge::ComGraphMakeUnique()) {} -template -OpContextBuilderBase::~OpContextBuilderBase() = default; - -template -T &OpContextBuilderBase::OpType(const ge::AscendString &op_type) { - GE_CHECK_NOTNULL_EXEC(impl_, return static_cast(*this)); - impl_->MutableOpInfo().op_type = op_type.GetString(); - return static_cast(*this); -} - -template -T &OpContextBuilderBase::OpName(const ge::AscendString &op_name) { - GE_CHECK_NOTNULL_EXEC(impl_, return static_cast(*this)); - impl_->MutableOpInfo().op_name = op_name.GetString(); - return static_cast(*this); -} - -template -T &OpContextBuilderBase::IONum(size_t input_ir_num, size_t output_ir_num) { - GE_CHECK_NOTNULL_EXEC(impl_, return static_cast(*this)); - auto &op_info = impl_->MutableOpInfo(); - if (!op_info.input_instance.empty() || !op_info.output_instance.empty()) { - GELOGW("IO has been set. Set IO Num failed!"); - return static_cast(*this); // 已经设置过输入输出, 无需不允许再次设置 - } - op_info.input_ir_num = input_ir_num; - op_info.output_ir_num = output_ir_num; - op_info.input_tensor_descs.resize(op_info.input_ir_num, ContextTensorDesc()); - op_info.output_tensor_descs.resize(op_info.output_ir_num, ContextTensorDesc()); - op_info.input_instance.resize(op_info.input_ir_num, 1); - op_info.output_instance.resize(op_info.output_ir_num, 1); - op_info.input_instance_num = op_info.input_ir_num; - op_info.output_instance_num = op_info.output_ir_num; - return static_cast(*this); -} - -template -T &OpContextBuilderBase::IOInstanceNum(const std::vector &input_instance_num, - const std::vector &output_instance_num) { - GE_CHECK_NOTNULL_EXEC(impl_, return static_cast(*this)); - auto &op_info = impl_->MutableOpInfo(); - op_info.input_instance = input_instance_num; - op_info.output_instance = output_instance_num; - op_info.input_ir_num = input_instance_num.size(); - op_info.output_ir_num = output_instance_num.size(); - op_info.input_instance_num = 0U; - op_info.output_instance_num = 0U; - for (const auto &num : op_info.input_instance) { - op_info.input_instance_num += num; - } - op_info.input_tensor_descs.resize(op_info.input_instance_num, ContextTensorDesc()); - for (const auto &num : op_info.output_instance) { - op_info.output_instance_num += num; - } - op_info.output_tensor_descs.resize(op_info.output_instance_num, ContextTensorDesc()); - return static_cast(*this); -} - -template -ge::DataType &OpContextBuilderBase::MutableInputDataType(size_t index) { - if (impl_ == nullptr || index >= impl_->MutableOpInfo().input_tensor_descs.size()) { - static ge::DataType default_dtype = ge::DT_MAX; - GELOGE(ge::PARAM_INVALID, "Input index %zu is out of range, input tensor desc size is %zu", index, - impl_->MutableOpInfo().input_tensor_descs.size()); - return default_dtype; - } - return impl_->MutableOpInfo().input_tensor_descs[index].dtype; -} - -template -ge::Format &OpContextBuilderBase::MutableInputOriginalFormat(size_t index) { - if (impl_ == nullptr || index >= impl_->MutableOpInfo().input_tensor_descs.size()) { - static ge::Format default_format = ge::FORMAT_MAX; - GELOGE(ge::PARAM_INVALID, "Input index %zu is out of range, input tensor desc size is %zu", index, - impl_->MutableOpInfo().input_tensor_descs.size()); - return default_format; - } - return impl_->MutableOpInfo().input_tensor_descs[index].origin_format; -} - -template -ge::Format &OpContextBuilderBase::MutableInputStorageFormat(size_t index) { - if (impl_ == nullptr || index >= impl_->MutableOpInfo().input_tensor_descs.size()) { - static ge::Format default_format = ge::FORMAT_MAX; - GELOGE(ge::PARAM_INVALID, "Input index %zu is out of range, input tensor desc size is %zu", index, - impl_->MutableOpInfo().input_tensor_descs.size()); - return default_format; - } - return impl_->MutableOpInfo().input_tensor_descs[index].storage_format; -} - -template -gert::ExpandDimsType &OpContextBuilderBase::MutableInputExpandDimsType(size_t index) { - if (impl_ == nullptr || index >= impl_->MutableOpInfo().input_tensor_descs.size()) { - static gert::ExpandDimsType default_expand_dims_type; - GELOGE(ge::PARAM_INVALID, "Input index %zu is out of range, input tensor desc size is %zu", index, - impl_->MutableOpInfo().input_tensor_descs.size()); - return default_expand_dims_type; - } - return impl_->MutableOpInfo().input_tensor_descs[index].expand_dims_type; -} - -template -DataType &OpContextBuilderBase::MutableOutputDataType(size_t index) { - if (impl_ == nullptr || index >= impl_->MutableOpInfo().output_tensor_descs.size()) { - static ge::DataType default_dtype = ge::DT_MAX; - GELOGE(ge::PARAM_INVALID, "Output index %zu is out of range, output tensor desc size is %zu", index, - impl_->MutableOpInfo().output_tensor_descs.size()); - return default_dtype; - } - return impl_->MutableOpInfo().output_tensor_descs[index].dtype; -} - -template -ge::Format &OpContextBuilderBase::MutableOutputOriginalFormat(size_t index) { - if (impl_ == nullptr || index >= impl_->MutableOpInfo().output_tensor_descs.size()) { - static ge::Format default_format = ge::FORMAT_MAX; - GELOGE(ge::PARAM_INVALID, "Output index %zu is out of range, output tensor desc size is %zu", index, - impl_->MutableOpInfo().output_tensor_descs.size()); - return default_format; - } - return impl_->MutableOpInfo().output_tensor_descs[index].origin_format; -} - -template -ge::Format &OpContextBuilderBase::MutableOutputStorageFormat(size_t index) { - if (impl_ == nullptr || index >= impl_->MutableOpInfo().output_tensor_descs.size()) { - static ge::Format default_format = ge::FORMAT_MAX; - GELOGE(ge::PARAM_INVALID, "Output index %zu is out of range, output tensor desc size is %zu", index, - impl_->MutableOpInfo().output_tensor_descs.size()); - return default_format; - } - return impl_->MutableOpInfo().output_tensor_descs[index].storage_format; -} - -template -gert::ExpandDimsType &OpContextBuilderBase::MutableOutputExpandDimsType(size_t index) { - if (impl_ == nullptr || index >= impl_->MutableOpInfo().output_tensor_descs.size()) { - static gert::ExpandDimsType default_expand_dims_type; - GELOGE(ge::PARAM_INVALID, "Output index %zu is out of range, output tensor desc size is %zu", index, - impl_->MutableOpInfo().output_tensor_descs.size()); - return default_expand_dims_type; - } - return impl_->MutableOpInfo().output_tensor_descs[index].expand_dims_type; -} - -template -T &OpContextBuilderBase::AppendAttr(bool attr) { - GE_CHECK_NOTNULL_EXEC(impl_, return static_cast(*this)); - impl_->MutableOpInfo().attrs.emplace_back(ge::AnyValue::CreateFrom(attr)); - return static_cast(*this); -} - -template -T &OpContextBuilderBase::AppendAttr(int64_t attr) { - GE_CHECK_NOTNULL_EXEC(impl_, return static_cast(*this)); - impl_->MutableOpInfo().attrs.emplace_back(ge::AnyValue::CreateFrom(attr)); - return static_cast(*this); -} - -template -T &OpContextBuilderBase::AppendAttr(float attr) { - GE_CHECK_NOTNULL_EXEC(impl_, return static_cast(*this)); - impl_->MutableOpInfo().attrs.emplace_back(ge::AnyValue::CreateFrom(attr)); - return static_cast(*this); -} - -template -T &OpContextBuilderBase::AppendAttr(const ge::AscendString &attr) { - GE_CHECK_NOTNULL_EXEC(impl_, return static_cast(*this)); - impl_->MutableOpInfo().attrs.emplace_back(ge::AnyValue::CreateFrom(attr.GetString())); - return static_cast(*this); -} - -template -T &OpContextBuilderBase::AppendAttr(const std::vector &attr) { - GE_CHECK_NOTNULL_EXEC(impl_, return static_cast(*this)); - impl_->MutableOpInfo().attrs.emplace_back(ge::AnyValue::CreateFrom>(attr)); - return static_cast(*this); -} - -template -T &OpContextBuilderBase::AppendAttr(const std::vector &attr) { - GE_CHECK_NOTNULL_EXEC(impl_, return static_cast(*this)); - impl_->MutableOpInfo().attrs.emplace_back(ge::AnyValue::CreateFrom>(attr)); - return static_cast(*this); -} - -template -T &OpContextBuilderBase::AppendAttr(const std::vector &attr) { - GE_CHECK_NOTNULL_EXEC(impl_, return static_cast(*this)); - impl_->MutableOpInfo().attrs.emplace_back(ge::AnyValue::CreateFrom>(attr)); - return static_cast(*this); -} - -template -T &OpContextBuilderBase::AppendAttr(const std::vector &attr) { - GE_CHECK_NOTNULL_EXEC(impl_, return static_cast(*this)); - std::vector attr_str; - for (auto &item : attr) { - attr_str.emplace_back(item.GetString()); - } - impl_->MutableOpInfo().attrs.emplace_back(ge::AnyValue::CreateFrom>(attr_str)); - return static_cast(*this); -} - -template -T &OpContextBuilderBase::AppendAttr(const std::vector> &attr) { - GE_CHECK_NOTNULL_EXEC(impl_, return static_cast(*this)); - impl_->MutableOpInfo().attrs.emplace_back(ge::AnyValue::CreateFrom>>(attr)); - return static_cast(*this); -} -template class OpContextBuilderBase; -template class OpContextBuilderBase; -template class OpContextBuilderBase; -template class OpContextBuilderBase; -template class OpContextBuilderBase; -template class OpContextBuilderBase; -} // namespace gert diff --git a/base/context_builder/op_context_builder_impl.cc b/base/context_builder/op_context_builder_impl.cc index 16526ff3c9..55d5c4d00f 100644 --- a/base/context_builder/op_context_builder_impl.cc +++ b/base/context_builder/op_context_builder_impl.cc @@ -9,30 +9,45 @@ #include "base/context_builder/op_context_builder_impl.h" #include "base/context_builder/op_kernel_run_context_builder.h" #include "base/context_builder/op_tiling_context_builder.h" +#include "base/context_builder/op_info_impl.h" #include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_util.h" #include "base/attr/attrs_to_buffer.h" namespace gert { +namespace { +const std::string kATTR_NAME_RESHAPE_TYPE_MASK = "_reshape_type_mask"; +} ge::graphStatus ContextBuilderImpl::SetCompileTimeTd(const ContextTensorDesc &desc, CompileTimeTensorDesc &td) { - td.SetDataType(desc.dtype); - td.SetOriginFormat(desc.origin_format); - td.SetStorageFormat(desc.storage_format); - td.SetExpandDimsType(desc.expand_dims_type); + td.SetDataType(desc.dtype_); + td.SetOriginFormat(desc.origin_format_); + td.SetStorageFormat(desc.storage_format_); + int64_t reshape_type_mask = 0; + auto op_info_impl = OpInfoHelper::GetPtr(op_info_); + GE_ASSERT_NOTNULL(op_info_impl, "OpInfoImpl is null"); + for (auto &attr : op_info_impl->attrs_) { + if (attr.first == kATTR_NAME_RESHAPE_TYPE_MASK.c_str()) { + if (attr.second.GetValue(reshape_type_mask) == 0) { + td.SetExpandDimsType(ExpandDimsType(reshape_type_mask)); + } + } + } return ge::GRAPH_SUCCESS; } ge::graphStatus ContextBuilderImpl::InitIOInstanceInfo(ComputeNodeInfo &compute_node_info) { size_t input_index = 0U; - for (size_t i = 0U; i < op_info_.input_instance.size(); ++i) { - size_t instance_num = op_info_.input_instance[i]; + auto op_info_impl = OpInfoHelper::GetPtr(op_info_); + GE_ASSERT_NOTNULL(op_info_impl, "OpInfoImpl is null"); + for (size_t i = 0U; i < op_info_impl->input_instance_.size(); ++i) { + size_t instance_num = op_info_impl->input_instance_[i]; compute_node_info.MutableInputInstanceInfo(i)->SetInstantiationNum(instance_num); compute_node_info.MutableInputInstanceInfo(i)->SetInstanceStart(input_index); input_index += instance_num; } size_t output_index = 0U; - for (size_t i = 0U; i < op_info_.output_instance.size(); ++i) { - size_t instance_num = op_info_.output_instance[i]; + for (size_t i = 0U; i < op_info_impl->output_instance_.size(); ++i) { + size_t instance_num = op_info_impl->output_instance_[i]; compute_node_info.MutableOutputInstanceInfo(i)->SetInstantiationNum(instance_num); compute_node_info.MutableOutputInstanceInfo(i)->SetInstanceStart(output_index); output_index += instance_num; @@ -41,15 +56,17 @@ ge::graphStatus ContextBuilderImpl::InitIOInstanceInfo(ComputeNodeInfo &compute_ } ge::graphStatus ContextBuilderImpl::InitCompileTimeTD(ComputeNodeInfo &compute_node_info) { + auto op_info_impl = OpInfoHelper::GetPtr(op_info_); + GE_ASSERT_NOTNULL(op_info_impl, "OpInfoImpl is null"); for (size_t i = 0U; i < compute_node_info.GetInputsNum(); ++i) { - const auto &ctx_io_desc = op_info_.input_tensor_descs[i]; + const auto &ctx_io_desc = op_info_impl->input_tensor_descs_[i]; auto td = compute_node_info.MutableInputTdInfo(i); GE_ASSERT_NOTNULL(td, "tensor desc in compute node info is null"); SetCompileTimeTd(ctx_io_desc, *td); } for (size_t i = 0U; i < compute_node_info.GetOutputsNum(); ++i) { - const auto &ctx_io_desc = op_info_.output_tensor_descs[i]; + const auto &ctx_io_desc = op_info_impl->output_tensor_descs_[i]; auto td = compute_node_info.MutableOutputTdInfo(i); GE_ASSERT_NOTNULL(td, "tensor desc in compute node info is null"); SetCompileTimeTd(ctx_io_desc, *td); @@ -58,24 +75,24 @@ ge::graphStatus ContextBuilderImpl::InitCompileTimeTD(ComputeNodeInfo &compute_n } std::unique_ptr ContextBuilderImpl::CreateComputeNodeInfoImpl(const std::unique_ptr &attr_buf, - size_t attr_size, - const OpInfo &op_info, + const size_t attr_size, + const OpInfoImpl &op_info, std::vector &string_pool, size_t &total_size) { - const size_t ir_input_num = op_info.input_ir_num; - const size_t ir_output_num = op_info.output_ir_num; - const size_t input_num = op_info.input_instance_num; - const size_t output_num = op_info.output_instance_num; - GELOGD("opinfo: %s(%s), ir_input_num:%zu, ir_output_num:%zu, input_num:%zu, output_num:%zu.", op_info.op_name.c_str(), - op_info.op_type.c_str(), ir_input_num, ir_output_num, input_num, output_num); + const size_t ir_input_num = op_info.input_ir_num_; + const size_t ir_output_num = op_info.output_ir_num_; + const size_t input_num = op_info.input_instance_num_; + const uint32_t output_num = op_info.output_instance_num_; + GELOGD("opinfo: %s(%s), ir_input_num:%zu, ir_output_num:%zu, input_num:%zu, output_num:%u.", op_info.op_name_.c_str(), + op_info.op_type_.c_str(), ir_input_num, ir_output_num, input_num, output_num); GE_ASSERT_SUCCESS(ComputeNodeInfo::CalcSize(ir_input_num, ir_output_num, input_num, output_num, total_size)); GE_ASSERT_TRUE(!ge::AddOverflow(total_size, attr_size, total_size)); auto compute_node_info_holder = ge::ComGraphMakeUnique(total_size); GE_ASSERT_NOTNULL(compute_node_info_holder, "Create compute node info holder failed"); auto idx = string_pool.size(); - string_pool.emplace_back(op_info.op_name); - string_pool.emplace_back(op_info.op_type); + string_pool.emplace_back(op_info.op_name_); + string_pool.emplace_back(op_info.op_type_); auto name_ptr = string_pool[idx].c_str(); auto type_ptr = string_pool[idx + 1].c_str(); @@ -83,10 +100,10 @@ std::unique_ptr ContextBuilderImpl::CreateComputeNodeInfoImpl(const s compute_node_info->Init(ir_input_num, ir_output_num, input_num, output_num, attr_size, name_ptr, type_ptr); auto ret = InitIOInstanceInfo(*compute_node_info); - GE_ASSERT_SUCCESS(ret, "Init input instance info for node:%s failed.", op_info.op_name.c_str()); + GE_ASSERT_SUCCESS(ret, "Init input instance info for node:%s failed.", op_info.op_name_.c_str()); ret = InitCompileTimeTD(*compute_node_info); - GE_ASSERT_SUCCESS(ret, "Init compile time tensor desc for node:%s failed.", op_info.op_name.c_str()); + GE_ASSERT_SUCCESS(ret, "Init compile time tensor desc for node:%s failed.", op_info.op_name_.c_str()); auto attr = compute_node_info->MutableAttrs(); const auto offset = ge::PtrToPtr(attr) - compute_node_info_holder.get(); @@ -100,27 +117,25 @@ std::unique_ptr ContextBuilderImpl::CreateComputeNodeInfoImpl(const s GE_ASSERT_SUCCESS(ret, "memcpy_s failed, copy size is %zu, dst size is %zu", attr_size, (total_size - offset - outputs_ins_info_size)); GELOGI("Node %s, compute_node_info attr_size %zu, outputs_ins_info_size:%zu, offset:%zu, total_size:%zu.", - op_info.op_name.c_str(), attr_size, outputs_ins_info_size, offset, total_size); + op_info.op_name_.c_str(), attr_size, outputs_ins_info_size, offset, total_size); return compute_node_info_holder; } ge::graphStatus ContextBuilderImpl::CreateComputeNodeInfo(ContextHolderImpl &holder) { - GE_ASSERT_TRUE((!op_info_.op_type.empty()) && (!op_info_.op_name.empty()) && (op_info_.input_ir_num != 0) && - (op_info_.output_ir_num != 0), - "Invalid params, op_type: %s, op_name: %s, input_num: %u, output_num: %u", op_info_.op_type.c_str(), - op_info_.op_name.c_str(), op_info_.input_ir_num, op_info_.output_ir_num); size_t attr_size; - auto attr_buf = bg::CreateAttrBufferWithAttrs(op_info_.attrs, attr_size); + auto op_info_impl = OpInfoHelper::GetPtr(op_info_); + GE_ASSERT_NOTNULL(op_info_impl, "OpInfoImpl is null"); + auto attr_buf = bg::CreateAttrBufferWithAttrs(op_info_impl->attrs_, attr_size); size_t total_size = 0U; holder.compute_node_info_holder_ = - CreateComputeNodeInfoImpl(attr_buf, attr_size, op_info_, holder.string_pool_, total_size); + CreateComputeNodeInfoImpl(attr_buf, attr_size, *op_info_impl, holder.string_pool_, total_size); return ge::GRAPH_SUCCESS; } ge::graphStatus ContextBuilderImpl::BuildCtx(ContextHolderImpl &holder) { - const auto in_size = input_values_.size(); - const auto out_size = output_values_.size(); - const auto io_size = in_size + out_size; - const auto size = sizeof(KernelRunContext) + sizeof(Chain *) * (io_size); + auto in_size = input_values_.size(); + auto out_size = output_values_.size(); + auto io_size = in_size + out_size; + size_t size = sizeof(KernelRunContext) + sizeof(Chain *) * (io_size); holder.context_holder_.reset(new (std::nothrow) uint8_t[size]); GE_ASSERT_NOTNULL(holder.context_holder_, "Create context holder failed."); diff --git a/base/context_builder/op_context_builder_impl.h b/base/context_builder/op_context_builder_impl.h index 4ed5e22764..7a9dbcf20f 100644 --- a/base/context_builder/op_context_builder_impl.h +++ b/base/context_builder/op_context_builder_impl.h @@ -17,14 +17,14 @@ #include "exe_graph/runtime/continuous_vector.h" #include "graph/debug/ge_log.h" #include "common/checker.h" -#include "op_info.h" +#include "base/context_builder/op_info_impl.h" #include "base/runtime/runtime_attrs_def.h" #include "graph/debug/ge_util.h" namespace gert { struct TilingInfo { const void *compile_info_ = nullptr; const void *platform_info_ = nullptr; - std::pair tiling_data_ = {nullptr, nullptr}; + const void *tiling_data_ = nullptr; int32_t deterministic_ = 0; const gert::ContinuousVector *workspace_ = nullptr; }; @@ -70,14 +70,14 @@ class ContextHolderImpl { class ContextBuilderImpl { public: - ContextBuilderImpl() = default; + ContextBuilderImpl() {} virtual ~ContextBuilderImpl() = default; - OpInfo &MutableOpInfo() { + OpInfo &GetOpInfo() { return op_info_; } - void SetCompiledInfo(const void *compile_info) { + void SetCompileInfo(const void *compile_info) { tiling_info_.compile_info_ = compile_info; } @@ -92,11 +92,8 @@ class ContextBuilderImpl { tiling_parse_info_.compiled_json_ = compiled_json; } - void SetTilingData(const void *tiling_data, gert::Chain::Deleter deleter) { - if (tiling_info_.tiling_data_.first != nullptr && tiling_info_.tiling_data_.second != nullptr) { - tiling_info_.tiling_data_.second(const_cast(tiling_info_.tiling_data_.first)); - } - tiling_info_.tiling_data_ = {tiling_data, deleter}; + void SetTilingData(const void *tiling_data) { + tiling_info_.tiling_data_ = tiling_data; } void SetWorkspace(const gert::ContinuousVector *workspace) { @@ -122,7 +119,7 @@ class ContextBuilderImpl { ge::graphStatus BuildRTOutputShapes(ContextHolderImpl &holder); ge::graphStatus InitCompileTimeTD(ComputeNodeInfo &compute_node_info); std::unique_ptr CreateComputeNodeInfoImpl(const std::unique_ptr &attr_buf, - size_t attr_size, const OpInfo &op_info, + const size_t attr_size, const OpInfoImpl &op_info, std::vector &string_pool, size_t &total_size); ge::graphStatus CreateComputeNodeInfo(ContextHolderImpl &holder); ge::graphStatus InitIOInstanceInfo(ComputeNodeInfo &compute_node_info); @@ -133,7 +130,6 @@ class ContextBuilderImpl { TilingParseInfo tiling_parse_info_; std::vector> input_values_; std::vector> output_values_; - bool use_data_type_ptr_{false}; }; } // namespace gert diff --git a/base/context_builder/op_infer_datatype_context_builder.cc b/base/context_builder/op_infer_datatype_context_builder.cc index 58df6b4f8e..0ffd2ff7b1 100644 --- a/base/context_builder/op_infer_datatype_context_builder.cc +++ b/base/context_builder/op_infer_datatype_context_builder.cc @@ -10,7 +10,8 @@ #include "base/context_builder/op_context_builder_impl.h" #include "base/context_builder/context_holder_builder.h" #include "common/ge_common/util.h" -#include "op_info.h" +#include "base/context_builder/op_info.h" +#include "base/context_builder/op_info_impl.h" #include #include "securec.h" #include "graph/debug/ge_util.h" @@ -22,50 +23,75 @@ class OpInferDataTypeContextBuilderImpl : public ContextBuilderImpl { ~OpInferDataTypeContextBuilderImpl() override = default; std::unique_ptr BuildInferDataTypeContext() { + auto op_info_impl = OpInfoHelper::GetPtr(op_info_); + GE_ASSERT_NOTNULL(op_info_impl, "op_info_impl is nullptr"); + GE_ASSERT_TRUE(op_info_impl->CheckParams(), "CheckParams failed, op_type: %s, op_name: %s", + op_info_impl->op_type_.c_str(), op_info_impl->op_name_.c_str()); auto holder = ge::ComGraphMakeUnique(); GE_ASSERT_NOTNULL(holder, "Create ContextHolderImpl failed."); GE_ASSERT_SUCCESS(CreateComputeNodeInfo(*holder), "Create compute node info failed."); - std::vector> tmp_outputs; - for (size_t i = 0U; i < op_info_.input_instance_num; ++i) { - input_values_.emplace_back(std::make_pair( - ge::ValueToPtr(MutableOpInfo().input_tensor_descs[i].dtype), nullptr)); + GE_ASSERT_SUCCESS(BuildCtx(*holder), "BuildCtx failed."); + auto origin_context = holder->GetContext(); + if (use_data_type_ptr_) { + for (size_t i = 0U; i < input_values_.size(); ++i) { + memcpy_s(origin_context->MutableInputPointer(i), sizeof(void *), input_values_[i].first, + sizeof(ge::DataType)); + } } - for (size_t i = 0U; i < op_info_.output_instance_num; ++i) { - output_values_.emplace_back(std::make_pair(ge::ValueToPtr(ge::DT_MAX), nullptr)); + for (size_t i = 0U; i < output_values_.size(); ++i) { + memcpy_s(origin_context->GetOutputPointer(i), sizeof(void *), output_values_[i].first, + sizeof(ge::DataType)); } - GE_ASSERT_SUCCESS(BuildCtx(*holder), "BuildCtx failed."); return holder; } -}; -static_assert(sizeof(OpInferDataTypeContextBuilderImpl) == sizeof(ContextBuilderImpl), "OpInferDataTypeContextBuilderImpl size error"); + void Inputs(const std::vector &inputs) { + if (!use_data_type_ptr_) { + input_values_.clear(); + } + for (auto i : inputs) { + input_values_.emplace_back(std::make_pair(i, nullptr)); + } + use_data_type_ptr_ = true; + } + + void Inputs(const std::vector &inputs) { + if (use_data_type_ptr_) { + input_values_.clear(); + } + for (auto i : inputs) { + input_values_.emplace_back(std::make_pair(ge::ValueToPtr(i), nullptr)); + } + use_data_type_ptr_ = false; + } + + void Outputs(const std::vector &outputs) { + for (auto output : outputs) { + output_values_.emplace_back(std::make_pair(output, nullptr)); + } + } + bool use_data_type_ptr_ = false; // Flag to indicate if input values are DataType pointers or values +}; OpInferDataTypeContextBuilder::OpInferDataTypeContextBuilder() - : OpContextBuilderBase() { - impl_ = ge::ComGraphMakeUnique(); -} + : impl_(ge::ComGraphMakeUnique()) {} +OpInferDataTypeContextBuilder::~OpInferDataTypeContextBuilder() {}; -OpInferDataTypeContextBuilder::~OpInferDataTypeContextBuilder() = default; +OpInferDataTypeContextBuilder &OpInferDataTypeContextBuilder::Inputs(const std::vector &inputs) { + GE_CHECK_NOTNULL_EXEC(impl_, return *this); + static_cast(impl_.get())->Inputs(inputs); + return *this; +} -OpInferDataTypeContextBuilder &OpInferDataTypeContextBuilder::InputTensorDesc(size_t index, ge::DataType dtype, - ge::Format origin_format, - ge::Format storage_format, - const gert::ExpandDimsType &expand_dims_type) { +OpInferDataTypeContextBuilder &OpInferDataTypeContextBuilder::Inputs(const std::vector &inputs) { GE_CHECK_NOTNULL_EXEC(impl_, return *this); - MutableInputDataType(index) = dtype; - MutableInputOriginalFormat(index) = origin_format; - MutableInputStorageFormat(index) = storage_format; - MutableInputExpandDimsType(index) = expand_dims_type; + static_cast(impl_.get())->Inputs(inputs); return *this; } -OpInferDataTypeContextBuilder &OpInferDataTypeContextBuilder::OutputTensorDesc(size_t index, ge::Format origin_format, - ge::Format storage_format, - const gert::ExpandDimsType &expand_dims_type) { +OpInferDataTypeContextBuilder &OpInferDataTypeContextBuilder::Outputs(const std::vector &outputs) { GE_CHECK_NOTNULL_EXEC(impl_, return *this); - MutableOutputOriginalFormat(index) = origin_format; - MutableOutputStorageFormat(index) = storage_format; - MutableOutputExpandDimsType(index) = expand_dims_type; + static_cast(impl_.get())->Outputs(outputs); return *this; } @@ -76,4 +102,10 @@ ContextHolder OpInferDataTypeContextBuilder::Build() { return ContextHolder(std::move(holder_void)); } +OpInfo &OpInferDataTypeContextBuilder::MutableOpInfo() { + static OpInfo null_op_info; + GE_CHECK_NOTNULL_EXEC(impl_, return null_op_info); + return impl_->GetOpInfo(); +} + } // namespace gert diff --git a/base/context_builder/op_infer_shape_context_builder.cc b/base/context_builder/op_infer_shape_context_builder.cc index 003a9eacca..936a52f305 100644 --- a/base/context_builder/op_infer_shape_context_builder.cc +++ b/base/context_builder/op_infer_shape_context_builder.cc @@ -10,76 +10,66 @@ #include "base/context_builder/op_context_builder_impl.h" #include "base/context_builder/context_holder_builder.h" #include "common/ge_common/util.h" -#include "op_info.h" +#include "base/context_builder/op_info.h" +#include "base/context_builder/op_info_impl.h" #include #include "graph/debug/ge_util.h" namespace gert { -class OpInferShapeContextBuilderImpl : public ContextBuilderImpl { +class OpOpInferShapeContextBuilderImpl : public ContextBuilderImpl { public: - OpInferShapeContextBuilderImpl() : ContextBuilderImpl() {} - ~OpInferShapeContextBuilderImpl() override = default; + OpOpInferShapeContextBuilderImpl() : ContextBuilderImpl() {} + ~OpOpInferShapeContextBuilderImpl() override = default; std::unique_ptr BuildInferShapeContext() { + auto op_info_impl = OpInfoHelper::GetPtr(op_info_); + GE_ASSERT_NOTNULL(op_info_impl, "op_info_impl is nullptr"); + GE_ASSERT_TRUE(op_info_impl->CheckParams(), "CheckParams failed, op_type: %s, op_name: %s", + op_info_impl->op_type_.c_str(), op_info_impl->op_name_.c_str()); auto holder = ge::ComGraphMakeUnique(); GE_ASSERT_NOTNULL(holder, "Create ContextHolderImpl failed."); GE_ASSERT_SUCCESS(CreateComputeNodeInfo(*holder), "Create compute node info failed."); input_values_.emplace_back(std::make_pair(nullptr, nullptr)); // FindInferShapeFunc - std::vector> tmp_outputs; - static auto shape_deleter = [](void *p) { - if (p == nullptr) { - return; - } - delete static_cast(p); - }; - for (size_t i = 0U; i < op_info_.output_instance_num; ++i) { - output_values_.emplace_back(new (std::nothrow) gert::Shape(), shape_deleter); - } GE_ASSERT_SUCCESS(BuildCtx(*holder), "BuildCtx failed."); return holder; } }; -static_assert(sizeof(OpInferShapeContextBuilderImpl) == sizeof(ContextBuilderImpl), "OpInferShapeContextBuilderImpl size error"); - OpInferShapeContextBuilder::OpInferShapeContextBuilder() - : OpContextBuilderBase() { - impl_ = ge::ComGraphMakeUnique(); -} - -OpInferShapeContextBuilder::~OpInferShapeContextBuilder() = default; + : impl_(ge::ComGraphMakeUnique()) {} +OpInferShapeContextBuilder::~OpInferShapeContextBuilder() {}; -OpInferShapeContextBuilder &OpInferShapeContextBuilder::OutputTensorDesc(size_t index, ge::DataType dtype, - ge::Format origin_format, - ge::Format storage_format, - const gert::ExpandDimsType &expand_dims_type) { +OpInferShapeContextBuilder &OpInferShapeContextBuilder::InputTensors(const std::vector &inputs) { GE_CHECK_NOTNULL_EXEC(impl_, return *this); - MutableOutputDataType(index) = dtype; - MutableOutputOriginalFormat(index) = origin_format; - MutableOutputStorageFormat(index) = storage_format; - MutableOutputExpandDimsType(index) = expand_dims_type; + std::vector tmp_inputs; + for (auto input : inputs) { + tmp_inputs.emplace_back(input); + } + impl_->Inputs(std::move(tmp_inputs)); return *this; } -OpInferShapeContextBuilder &OpInferShapeContextBuilder::InputTensors(const std::vector &inputs) { +OpInferShapeContextBuilder &OpInferShapeContextBuilder::OutputShapes(const std::vector &outputs) { GE_CHECK_NOTNULL_EXEC(impl_, return *this); - std::vector tmp_inputs; - for (size_t i = 0; i < inputs.size(); ++i) { - MutableInputDataType(i) = inputs[i]->GetDataType(); - MutableInputOriginalFormat(i) = inputs[i]->GetOriginFormat(); - MutableInputStorageFormat(i) = inputs[i]->GetStorageFormat(); - MutableInputExpandDimsType(i) = inputs[i]->GetExpandDimsType(); - tmp_inputs.emplace_back(inputs[i]); + std::vector tmp_outputs; + for (auto output : outputs) { + tmp_outputs.emplace_back(output); } - impl_->Inputs(std::move(tmp_inputs)); + impl_->Outputs(std::move(tmp_outputs)); return *this; } ContextHolder OpInferShapeContextBuilder::Build() { GE_CHECK_NOTNULL_EXEC(impl_, return ContextHolder()); - auto ctx_holder_impl = static_cast(impl_.get())->BuildInferShapeContext(); + auto ctx_holder_impl = static_cast(impl_.get())->BuildInferShapeContext(); auto holder_void = ContextHolderBuilder::Create(std::move(ctx_holder_impl)); return ContextHolder(std::move(holder_void)); } +OpInfo &OpInferShapeContextBuilder::MutableOpInfo() { + static OpInfo null_op_info; + GE_CHECK_NOTNULL_EXEC(impl_, return null_op_info); + return impl_->GetOpInfo(); +} + } // namespace gert diff --git a/base/context_builder/op_infer_shape_range_context_builder.cc b/base/context_builder/op_infer_shape_range_context_builder.cc index d9ca3ab5a9..1cd42d23f2 100644 --- a/base/context_builder/op_infer_shape_range_context_builder.cc +++ b/base/context_builder/op_infer_shape_range_context_builder.cc @@ -11,7 +11,8 @@ #include "base/context_builder/op_context_builder_impl.h" #include "base/context_builder/context_holder_builder.h" #include "common/ge_common/util.h" -#include "op_info.h" +#include "base/context_builder/op_info.h" +#include "base/context_builder/op_info_impl.h" #include #include "graph/debug/ge_util.h" @@ -22,77 +23,40 @@ class OpInferShapeRangeContextBuilderImpl : public ContextBuilderImpl { ~OpInferShapeRangeContextBuilderImpl() override = default; std::unique_ptr BuildInferShapeRangeContext() { + auto op_info_impl = OpInfoHelper::GetPtr(op_info_); + GE_ASSERT_NOTNULL(op_info_impl, "op_info_impl is nullptr"); + GE_ASSERT_TRUE(op_info_impl->CheckParams(), "CheckParams failed, op_type: %s, op_name: %s", + op_info_impl->op_type_.c_str(), op_info_impl->op_name_.c_str()); auto holder = ge::ComGraphMakeUnique(); GE_ASSERT_NOTNULL(holder, "Create ContextHolderImpl failed."); GE_ASSERT_SUCCESS(CreateComputeNodeInfo(*holder), "Create compute node info failed."); - std::vector> tmp_outputs; - static auto shape_range_deleter = [](void *p) { - if (p == nullptr) { - return; - } - delete static_cast*>(p)->GetMin(); - delete static_cast*>(p)->GetMax(); - delete static_cast*>(p); - }; - for (size_t i = 0U; i < op_info_.output_instance_num; ++i) { - auto min_shape = ge::ComGraphMakeUnique(); - GE_ASSERT_NOTNULL(min_shape, "Create min shape failed."); - auto max_shape = ge::ComGraphMakeUnique(); - GE_ASSERT_NOTNULL(max_shape, "Create max shape failed."); - auto range = ge::ComGraphMakeUnique>(min_shape.release(), max_shape.release()); - GE_ASSERT_NOTNULL(range, "Create range failed."); - output_values_.emplace_back(range.release(), shape_range_deleter); - } GE_ASSERT_SUCCESS(BuildCtx(*holder), "BuildCtx failed."); return holder; } }; -static_assert(sizeof(OpInferShapeRangeContextBuilderImpl) == sizeof(ContextBuilderImpl), "OpInferShapeRangeContextBuilderImpl size error"); - OpInferShapeRangeContextBuilder::OpInferShapeRangeContextBuilder() - : OpContextBuilderBase() { - impl_ = ge::ComGraphMakeUnique(); -} -OpInferShapeRangeContextBuilder::~OpInferShapeRangeContextBuilder() = default; + : impl_(ge::ComGraphMakeUnique()) {} +OpInferShapeRangeContextBuilder::~OpInferShapeRangeContextBuilder() {}; -OpInferShapeRangeContextBuilder &OpInferShapeRangeContextBuilder::InputTensorsRange( +OpInferShapeRangeContextBuilder &OpInferShapeRangeContextBuilder::InputTensors( const std::vector *> &inputs) { GE_CHECK_NOTNULL_EXEC(impl_, return *this); std::vector tmp_inputs; - for (size_t i = 0; i < inputs.size(); ++i) { - if (inputs[i]->GetMin()->GetDataType() != inputs[i]->GetMax()->GetDataType() || - inputs[i]->GetMin()->GetFormat() != inputs[i]->GetMax()->GetFormat() || - inputs[i]->GetMin()->GetStorageFormat() != inputs[i]->GetMax()->GetStorageFormat() || - !(inputs[i]->GetMin()->GetExpandDimsType() == inputs[i]->GetMax()->GetExpandDimsType())) { - GELOGE(ge::PARAM_INVALID, - "Index %zu, Input max and min tensor data type/origin_format/storage_format are not equal, " - "min data type: %d, max data type: %d, min format: %d, max format: %d, min storage format: %d, max " - "storage format: %d, min expand dims type: %d, max expand dims type: %d", - i, inputs[i]->GetMin()->GetDataType(), inputs[i]->GetMax()->GetDataType(), - inputs[i]->GetMin()->GetFormat(), inputs[i]->GetMax()->GetFormat(), - inputs[i]->GetMin()->GetStorageFormat(), inputs[i]->GetMax()->GetStorageFormat(), - inputs[i]->GetMin()->GetExpandDimsType(), inputs[i]->GetMax()->GetExpandDimsType()); - return *this; - } - MutableInputDataType(i) = inputs[i]->GetMin()->GetDataType(); - MutableInputOriginalFormat(i) = inputs[i]->GetMin()->GetOriginFormat(); - MutableInputStorageFormat(i) = inputs[i]->GetMin()->GetStorageFormat(); - MutableInputExpandDimsType(i) = inputs[i]->GetMin()->GetExpandDimsType(); - tmp_inputs.emplace_back(inputs[i]); + for (auto input : inputs) { + tmp_inputs.emplace_back(input); } impl_->Inputs(std::move(tmp_inputs)); return *this; } -OpInferShapeRangeContextBuilder &OpInferShapeRangeContextBuilder::OutputTensorDesc(size_t index, ge::DataType dtype, - ge::Format origin_format, - ge::Format storage_format, - const gert::ExpandDimsType &expand_dims_type) { +OpInferShapeRangeContextBuilder &OpInferShapeRangeContextBuilder::OutputShapes( + const std::vector *> &outputs) { GE_CHECK_NOTNULL_EXEC(impl_, return *this); - MutableOutputDataType(index) = dtype; - MutableOutputOriginalFormat(index) = origin_format; - MutableOutputStorageFormat(index) = storage_format; - MutableOutputExpandDimsType(index) = expand_dims_type; + std::vector tmp_outputs; + for (auto output : outputs) { + tmp_outputs.emplace_back(output); + } + impl_->Outputs(std::move(tmp_outputs)); return *this; } @@ -103,4 +67,10 @@ ContextHolder OpInferShapeRangeContextBuilder::Build() { return ContextHolder(std::move(holder_void)); } +OpInfo &OpInferShapeRangeContextBuilder::MutableOpInfo() { + static OpInfo null_op_info; + GE_CHECK_NOTNULL_EXEC(impl_, return null_op_info); + return impl_->GetOpInfo(); +} + } // namespace gert diff --git a/base/context_builder/op_info.cc b/base/context_builder/op_info.cc new file mode 100644 index 0000000000..ae94350222 --- /dev/null +++ b/base/context_builder/op_info.cc @@ -0,0 +1,106 @@ +/* Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ +#include "base/context_builder/op_info.h" +#include +#include "exe_graph/runtime/tensor.h" +#include "base/context_builder/op_info_impl.h" +#include "common/ge_common/util.h" +#include "graph/debug/ge_util.h" +namespace gert { + +OpInfo::OpInfo() : impl_(ge::ComGraphMakeShared()) {} +OpInfo::~OpInfo() = default; + +OpInfo &OpInfo::OpType(const AscendString &op_type) { + GE_CHECK_NOTNULL_EXEC(impl_, return *this); + impl_->SetOpType(op_type); + return *this; +} + +OpInfo &OpInfo::OpName(const AscendString &op_name) { + GE_CHECK_NOTNULL_EXEC(impl_, return *this); + impl_->SetOpName(op_name); + return *this; +} + +OpInfo &OpInfo::IONum(size_t input_num, size_t output_num) { + GE_CHECK_NOTNULL_EXEC(impl_, return *this); + impl_->SetIONum(input_num, output_num); + return *this; +} + +OpInfo &OpInfo::IOInstanceNum(const std::vector &input_instance_num, + const std::vector &output_instance_num) { + GE_CHECK_NOTNULL_EXEC(impl_, return *this); + impl_->SetIOInstanceNum(input_instance_num, output_instance_num); + return *this; +} +OpInfo &OpInfo::SetInputTd(size_t index, ge::DataType dtype, ge::Format origin_format, ge::Format storage_format, + const gert::StorageShape &shape) { + GE_CHECK_NOTNULL_EXEC(impl_, return *this); + impl_->SetInputTd(index, dtype, origin_format, storage_format, shape); + return *this; +} +OpInfo &OpInfo::SetOutputTd(size_t index, ge::DataType dtype, ge::Format origin_format, ge::Format storage_format, + const gert::StorageShape &shape) { + GE_CHECK_NOTNULL_EXEC(impl_, return *this); + impl_->SetOutputTd(index, dtype, origin_format, storage_format, shape); + return *this; +} +OpInfo &OpInfo::Attr(const AscendString &attr_name, bool attr) { + GE_CHECK_NOTNULL_EXEC(impl_, return *this); + impl_->SetAttr(attr_name, attr); + return *this; +} +OpInfo &OpInfo::Attr(const AscendString &attr_name, int64_t attr) { + GE_CHECK_NOTNULL_EXEC(impl_, return *this); + impl_->SetAttr(attr_name, attr); + return *this; +} +OpInfo &OpInfo::Attr(const AscendString &attr_name, float attr) { + GE_CHECK_NOTNULL_EXEC(impl_, return *this); + impl_->SetAttr(attr_name, attr); + return *this; +} +OpInfo &OpInfo::Attr(const AscendString &attr_name, const AscendString &attr) { + GE_CHECK_NOTNULL_EXEC(impl_, return *this); + std::string tmp(attr.GetString()); + impl_->SetAttr(attr_name, tmp); + return *this; +} +OpInfo &OpInfo::Attr(const AscendString &attr_name, const std::vector &attr) { + GE_CHECK_NOTNULL_EXEC(impl_, return *this); + impl_->SetAttr(attr_name, attr); + return *this; +} +OpInfo &OpInfo::Attr(const AscendString &attr_name, const std::vector &attr) { + GE_CHECK_NOTNULL_EXEC(impl_, return *this); + impl_->SetAttr(attr_name, attr); + return *this; +} +OpInfo &OpInfo::Attr(const AscendString &attr_name, const std::vector &attr) { + GE_CHECK_NOTNULL_EXEC(impl_, return *this); + impl_->SetAttr(attr_name, attr); + return *this; +} +OpInfo &OpInfo::Attr(const AscendString &attr_name, const std::vector &attr) { + GE_CHECK_NOTNULL_EXEC(impl_, return *this); + std::vector tmp; + for (const auto &item : attr) { + tmp.emplace_back(item.GetString()); + } + impl_->SetAttr(attr_name, tmp); + return *this; +} +OpInfo &OpInfo::Attr(const AscendString &attr_name, const std::vector> &attr) { + GE_CHECK_NOTNULL_EXEC(impl_, return *this); + impl_->SetAttr(attr_name, attr); + return *this; +} +} // namespace gert diff --git a/base/context_builder/op_info.h b/base/context_builder/op_info.h deleted file mode 100644 index 3139a38109..0000000000 --- a/base/context_builder/op_info.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_OP_INFO_H_ -#define METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_OP_INFO_H_ -#include -#include "graph/types.h" -#include "exe_graph/runtime/storage_shape.h" -#include "exe_graph/runtime/shape.h" -#include "graph/ascend_string.h" -#include "graph/any_value.h" - -using namespace ge; -namespace gert { -struct ContextTensorDesc { - ge::DataType dtype = ge::DataType::DT_MAX; - ge::Format origin_format = ge::Format::FORMAT_MAX; - ge::Format storage_format = ge::Format::FORMAT_MAX; - gert::ExpandDimsType expand_dims_type; - gert::StorageShape storage_shape = {}; -}; - -struct OpInfo { - std::string op_type; - std::string op_name; - std::vector attrs; // 传递给context中的compute_node_info - uint32_t input_ir_num = 0U; - uint32_t input_instance_num = 0U; - uint32_t output_ir_num = 0U; - uint32_t output_instance_num = 0U; - std::vector input_instance; - std::vector output_instance; - std::vector input_tensor_descs; // 传递给context中的compute_node_info - std::vector output_tensor_descs; // 传递给context中的compute_node_info -}; -} // namespace gert -#endif // METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_OP_INFO_H_ \ No newline at end of file diff --git a/base/context_builder/op_info_impl.h b/base/context_builder/op_info_impl.h new file mode 100644 index 0000000000..3bd0af27bd --- /dev/null +++ b/base/context_builder/op_info_impl.h @@ -0,0 +1,131 @@ +/* Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ +#ifndef METADEF_BASE_CONTEXT_BUILDER_OP_INFO_IMPL_H_ +#define METADEF_BASE_CONTEXT_BUILDER_OP_INFO_IMPL_H_ +#include +#include +#include "types.h" +#include "graph/ascend_string.h" +#include "exe_graph/runtime/storage_shape.h" +#include "exe_graph/runtime/range.h" +#include "exe_graph/runtime/shape.h" +#include "graph/any_value.h" +#include "common/checker.h" +#include "exe_graph/runtime/tensor.h" +#include "exe_graph/runtime/kernel_context.h" +#include "base/context_builder/op_info.h" +using namespace ge; + +namespace gert { +struct ContextTensorDesc { + ge::DataType dtype_ = ge::DataType::DT_MAX; + ge::Format origin_format_ = ge::Format::FORMAT_MAX; + ge::Format storage_format_ = ge::Format::FORMAT_MAX; + gert::StorageShape shape_ = {}; +}; +struct OpInfoImpl { + OpInfoImpl() = default; + ~OpInfoImpl() = default; + + void SetOpType(const AscendString &op_type) { + op_type_ = op_type.GetString(); + } + void SetOpName(const AscendString &op_name) { + op_name_ = op_name.GetString(); + } + void SetIONum(size_t input_num, size_t output_num) { + if (!input_instance_.empty() || !output_instance_.empty()) { + GELOGW("IO has been set. Set IO Num failed!"); + return; // 已经设置过输入输出, 无需不允许再次设置 + } + input_ir_num_ = input_num; + output_ir_num_ = output_num; + input_tensor_descs_.resize(input_ir_num_); + output_tensor_descs_.resize(output_ir_num_); + input_instance_.resize(input_ir_num_, 1); + output_instance_.resize(output_ir_num_, 1); + input_instance_num_ = input_ir_num_; + output_instance_num_ = output_ir_num_; + } + void SetIOInstanceNum(const std::vector &input_instance, const std::vector &output_instance) { + input_instance_ = input_instance; + output_instance_ = output_instance; + input_ir_num_ = input_instance.size(); + output_ir_num_ = output_instance.size(); + input_instance_num_ = 0U; + output_instance_num_ = 0U; + for (const auto &num : input_instance) { + input_instance_num_ += num; + } + input_tensor_descs_.resize(input_instance_num_); + for (const auto &num : output_instance) { + output_instance_num_ += num; + } + output_tensor_descs_.resize(output_instance_num_); + } + void SetInputTd(size_t index, ge::DataType dtype, ge::Format origin_format, ge::Format storage_format, + const gert::StorageShape &shape) { + if (index >= input_tensor_descs_.size()) { + GELOGE(ge::FAILED, "Input tensor index %zu exceeds size %zu", index, input_tensor_descs_.size()); + return; + } + auto td = ContextTensorDesc(); + td.dtype_ = dtype; + td.origin_format_ = origin_format; + td.storage_format_ = storage_format; + td.shape_ = shape; + input_tensor_descs_[index] = std::move(td); + } + void SetOutputTd(size_t index, ge::DataType dtype, ge::Format origin_format, ge::Format storage_format, + const gert::StorageShape &shape) { + if (index >= output_tensor_descs_.size()) { + GELOGE(ge::FAILED, "Output tensor index %zu exceeds size %zu", index, output_tensor_descs_.size()); + return; + } + auto td = ContextTensorDesc(); + td.dtype_ = dtype; + td.origin_format_ = origin_format; + td.storage_format_ = storage_format; + td.shape_ = shape; + output_tensor_descs_[index] = std::move(td); + } + template + void SetAttr(const AscendString &attr_name, AttrTypeT attr) { + attrs_.emplace_back( + std::pair(attr_name.GetString(), ge::AnyValue::CreateFrom(attr))); + } + + bool CheckParams() const { + GE_ASSERT_TRUE((!op_type_.empty()) && (!op_name_.empty()) && (input_ir_num_ != 0) && (output_ir_num_ != 0), + "Invalid params, op_type: %s, op_name: %s, input_num: %u, output_num: %u", op_type_.c_str(), + op_name_.c_str(), input_ir_num_, output_ir_num_); + return true; + } + + std::string op_type_; + std::string op_name_; + std::vector> attrs_; // 传递给context中的compute_node_info + uint32_t input_ir_num_ = 0U; + uint32_t input_instance_num_ = 0U; + uint32_t output_ir_num_ = 0U; + uint32_t output_instance_num_ = 0U; + std::vector input_instance_; + std::vector output_instance_; + std::vector input_tensor_descs_; // 传递给context中的compute_node_info + std::vector output_tensor_descs_; // 传递给context中的compute_node_info +}; +class OpInfoHelper { + public: + static OpInfoImpl *GetPtr(const OpInfo &op_info) { + return op_info.impl_.get(); + } +}; +} // namespace gert + +#endif // METADEF_BASE_CONTEXT_BUILDER_OP_INFO_IMPL_H_ \ No newline at end of file diff --git a/base/context_builder/op_kernel_run_context_builder.cc b/base/context_builder/op_kernel_run_context_builder.cc index 1ffabe1370..3f86642a8e 100644 --- a/base/context_builder/op_kernel_run_context_builder.cc +++ b/base/context_builder/op_kernel_run_context_builder.cc @@ -10,7 +10,8 @@ #include "base/context_builder/op_context_builder_impl.h" #include "base/context_builder/context_holder_builder.h" #include "common/ge_common/util.h" -#include "op_info.h" +#include "base/context_builder/op_info.h" +#include "base/context_builder/op_info_impl.h" #include #include "graph/debug/ge_util.h" @@ -21,6 +22,10 @@ class OpKernelRunContextBuilderImpl : public ContextBuilderImpl { ~OpKernelRunContextBuilderImpl() override = default; std::unique_ptr BuildKernelRunContext() { + auto op_info_impl = OpInfoHelper::GetPtr(op_info_); + GE_ASSERT_NOTNULL(op_info_impl, "op_info_impl is nullptr"); + GE_ASSERT_TRUE(op_info_impl->CheckParams(), "CheckParams failed, op_type: %s, op_name: %s", + op_info_impl->op_type_.c_str(), op_info_impl->op_name_.c_str()); auto holder = ge::ComGraphMakeUnique(); GE_ASSERT_NOTNULL(holder, "Create ContextHolderImpl failed."); GE_ASSERT_SUCCESS(CreateComputeNodeInfo(*holder), "Create compute node info failed."); @@ -28,14 +33,9 @@ class OpKernelRunContextBuilderImpl : public ContextBuilderImpl { return holder; } }; -static_assert(sizeof(OpKernelRunContextBuilderImpl) == sizeof(ContextBuilderImpl), "OpKernelRunContextBuilderImpl size error"); - OpKernelRunContextBuilder::OpKernelRunContextBuilder() - : OpContextBuilderBase() { - impl_ = ge::ComGraphMakeUnique(); -} - -OpKernelRunContextBuilder::~OpKernelRunContextBuilder() = default; + : impl_(ge::ComGraphMakeUnique()) {} +OpKernelRunContextBuilder::~OpKernelRunContextBuilder() {}; OpKernelRunContextBuilder &OpKernelRunContextBuilder::Inputs(std::vector inputs) { GE_CHECK_NOTNULL_EXEC(impl_, return *this); @@ -49,30 +49,6 @@ OpKernelRunContextBuilder &OpKernelRunContextBuilder::Outputs(std::vector OpKernelRunContextBuilder::Build() { GE_CHECK_NOTNULL_EXEC(impl_, return ContextHolder()); auto ctx_holder_impl = static_cast(impl_.get())->BuildKernelRunContext(); @@ -80,4 +56,10 @@ ContextHolder OpKernelRunContextBuilder::Build() { return ContextHolder(std::move(holder_void)); } +OpInfo &OpKernelRunContextBuilder::MutableOpInfo() { + static OpInfo null_op_info; + GE_CHECK_NOTNULL_EXEC(impl_, return null_op_info); + return impl_->GetOpInfo(); +} + } // namespace gert diff --git a/base/context_builder/op_tiling_context_builder.cc b/base/context_builder/op_tiling_context_builder.cc index cd4621dfc1..cd963aff4b 100644 --- a/base/context_builder/op_tiling_context_builder.cc +++ b/base/context_builder/op_tiling_context_builder.cc @@ -9,7 +9,8 @@ #include "base/context_builder/op_tiling_context_builder.h" #include "base/context_builder/op_context_builder_impl.h" #include "base/context_builder/context_holder_builder.h" -#include "op_info.h" +#include "base/context_builder/op_info.h" +#include "base/context_builder/op_info_impl.h" #include "common/ge_common/util.h" #include #include "graph/debug/ge_util.h" @@ -21,6 +22,10 @@ class OpTilingContextBuilderImpl : public ContextBuilderImpl { ~OpTilingContextBuilderImpl() override = default; std::unique_ptr BuildTilingContext() { + auto op_info_impl = OpInfoHelper::GetPtr(op_info_); + GE_ASSERT_NOTNULL(op_info_impl, "op_info_impl is nullptr"); + GE_ASSERT_TRUE(op_info_impl->CheckParams(), "CheckParams failed, op_type: %s, op_name: %s", + op_info_impl->op_type_.c_str(), op_info_impl->op_name_.c_str()); GE_ASSERT_NOTNULL(tiling_info_.compile_info_, "Compile info is nullptr"); GE_ASSERT_NOTNULL(tiling_info_.platform_info_, "Platform info is nullptr"); @@ -28,32 +33,26 @@ class OpTilingContextBuilderImpl : public ContextBuilderImpl { GE_ASSERT_NOTNULL(holder, "Create ContextHolderImpl failed."); GE_ASSERT_SUCCESS(CreateComputeNodeInfo(*holder), "Create compute node info failed."); for (auto& outvaue : output_values_) { - input_values_.emplace_back(outvaue.first, outvaue.second); + input_values_.emplace_back(std::make_pair(outvaue.first, outvaue.second)); } - input_values_.emplace_back(tiling_info_.compile_info_, nullptr); // TilingCompileInfo - input_values_.emplace_back(tiling_info_.platform_info_, nullptr); // PlatformInfo - input_values_.emplace_back(nullptr, nullptr); // PrepareTilingFrameworkData + input_values_.emplace_back(std::make_pair(tiling_info_.compile_info_, nullptr)); // TilingCompileInfo + input_values_.emplace_back(std::make_pair(tiling_info_.platform_info_, nullptr)); // PlatformInfo + input_values_.emplace_back(std::make_pair(nullptr, nullptr)); // PrepareTilingFrameworkData input_values_.emplace_back( - reinterpret_cast(tiling_info_.deterministic_), nullptr); // Deterministic + std::make_pair(reinterpret_cast(tiling_info_.deterministic_), nullptr)); // Deterministic output_values_.resize(TilingContext::kOutputNum); - output_values_[TilingContext::kOutputTilingData] = - std::make_pair(tiling_info_.tiling_data_.first, tiling_info_.tiling_data_.second); + output_values_[TilingContext::kOutputTilingData] = std::make_pair(tiling_info_.tiling_data_, nullptr); output_values_[TilingContext::kOutputWorkspace] = std::make_pair(tiling_info_.workspace_, nullptr); GE_ASSERT_SUCCESS(BuildCtx(*holder), "BuildCtx failed."); return holder; } }; -static_assert(sizeof(OpTilingContextBuilderImpl) == sizeof(ContextBuilderImpl), "OpTilingContextBuilderImpl size error"); - -OpTilingContextBuilder::OpTilingContextBuilder() : OpContextBuilderBase() { - impl_ = ge::ComGraphMakeUnique(); -} - -OpTilingContextBuilder::~OpTilingContextBuilder() = default; +OpTilingContextBuilder::OpTilingContextBuilder() : impl_(ge::ComGraphMakeUnique()) {} +OpTilingContextBuilder::~OpTilingContextBuilder() {}; OpTilingContextBuilder &OpTilingContextBuilder::CompileInfo(const void *compile_info) { GE_CHECK_NOTNULL_EXEC(impl_, return *this); - impl_->SetCompiledInfo(compile_info); + impl_->SetCompileInfo(compile_info); return *this; } @@ -63,31 +62,14 @@ OpTilingContextBuilder &OpTilingContextBuilder::PlatformInfo(const void *platfor return *this; } OpTilingContextBuilder &OpTilingContextBuilder::Deterministic(int32_t deterministic) { - if (deterministic != 0 && deterministic != 1) { - GELOGE(ge::PARAM_INVALID, "Deterministic value is invalid, expect 0 or 1, but got %d", deterministic); - return *this; - } GE_CHECK_NOTNULL_EXEC(impl_, return *this); impl_->SetDeterministic(deterministic); return *this; } -OpTilingContextBuilder &OpTilingContextBuilder::TilingData(const gert::TilingData *tiling_data, - gert::Chain::Deleter deleter) { +OpTilingContextBuilder &OpTilingContextBuilder::TilingData(const gert::TilingData *tiling_data) { GE_CHECK_NOTNULL_EXEC(impl_, return *this); - impl_->SetTilingData(tiling_data, deleter); - return *this; -} - -OpTilingContextBuilder &OpTilingContextBuilder::TilingDataSize(size_t tiling_data_size) { - GE_CHECK_NOTNULL_EXEC(impl_, return *this); - auto tiling_data = TilingData::CreateCap(tiling_data_size); - static auto delete_tiling_data = [] (void *data) { - if (data != nullptr) { - delete [] static_cast(data); - } - }; - impl_->SetTilingData(tiling_data.release(), delete_tiling_data); + impl_->SetTilingData(tiling_data); return *this; } @@ -100,12 +82,8 @@ OpTilingContextBuilder &OpTilingContextBuilder::Workspace(const gert::Continuous OpTilingContextBuilder &OpTilingContextBuilder::InputTensors(const std::vector &inputs) { GE_CHECK_NOTNULL_EXEC(impl_, return *this); std::vector tmp_inputs; - for (size_t i = 0; i < inputs.size(); ++i) { - MutableInputDataType(i) = inputs[i]->GetDataType(); - MutableInputOriginalFormat(i) = inputs[i]->GetOriginFormat(); - MutableInputStorageFormat(i) = inputs[i]->GetStorageFormat(); - MutableInputExpandDimsType(i) = inputs[i]->GetExpandDimsType(); - tmp_inputs.emplace_back(inputs[i]); + for (auto input : inputs) { + tmp_inputs.emplace_back(input); } impl_->Inputs(std::move(tmp_inputs)); return *this; @@ -114,12 +92,8 @@ OpTilingContextBuilder &OpTilingContextBuilder::InputTensors(const std::vector &outputs) { GE_CHECK_NOTNULL_EXEC(impl_, return *this); std::vector tmp_outputs; - for (size_t i = 0; i < outputs.size(); ++i) { - MutableOutputDataType(i) = outputs[i]->GetDataType(); - MutableOutputOriginalFormat(i) = outputs[i]->GetOriginFormat(); - MutableOutputStorageFormat(i) = outputs[i]->GetStorageFormat(); - MutableOutputExpandDimsType(i) = outputs[i]->GetExpandDimsType(); - tmp_outputs.emplace_back(outputs[i]); + for (auto output : outputs) { + tmp_outputs.emplace_back(output); } impl_->Outputs(std::move(tmp_outputs)); return *this; @@ -132,4 +106,10 @@ ContextHolder OpTilingContextBuilder::Build() { return ContextHolder(std::move(holder_void)); } +OpInfo &OpTilingContextBuilder::MutableOpInfo() { + static OpInfo null_op_info; + GE_CHECK_NOTNULL_EXEC(impl_, return null_op_info); + return impl_->GetOpInfo(); +} + } // namespace gert \ No newline at end of file diff --git a/base/context_builder/op_tiling_parse_context_builder.cc b/base/context_builder/op_tiling_parse_context_builder.cc index 9b1e2875de..13097d51e8 100644 --- a/base/context_builder/op_tiling_parse_context_builder.cc +++ b/base/context_builder/op_tiling_parse_context_builder.cc @@ -10,7 +10,8 @@ #include "base/context_builder/op_context_builder_impl.h" #include "base/context_builder/context_holder_builder.h" #include "common/ge_common/util.h" -#include "op_info.h" +#include "base/context_builder/op_info.h" +#include "base/context_builder/op_info_impl.h" #include #include "graph/debug/ge_util.h" @@ -21,6 +22,10 @@ class OpTilingParseContextBuilderImpl : public ContextBuilderImpl { ~OpTilingParseContextBuilderImpl() override = default; std::unique_ptr BuildTilingParseContext() { + auto op_info_impl = OpInfoHelper::GetPtr(op_info_); + GE_ASSERT_NOTNULL(op_info_impl, "op_info_impl is nullptr"); + GE_ASSERT_TRUE(op_info_impl->CheckParams(), "CheckParams failed, op_type: %s, op_name: %s", + op_info_impl->op_type_.c_str(), op_info_impl->op_name_.c_str()); GE_ASSERT_NOTNULL(tiling_parse_info_.compiled_json_, "Compile json is nullptr"); GE_ASSERT_NOTNULL(tiling_info_.platform_info_, "Platform info is nullptr"); output_values_.clear(); @@ -31,32 +36,28 @@ class OpTilingParseContextBuilderImpl : public ContextBuilderImpl { GE_ASSERT_SUCCESS(CreateComputeNodeInfo(*holder), "Create compute node info failed."); auto compute_node_info = ge::PtrToPtr(holder->compute_node_info_holder_.get()); input_values_.clear(); - input_values_.emplace_back(reinterpret_cast(const_cast(tiling_parse_info_.compiled_json_)), - static_cast(nullptr)); - input_values_.emplace_back(tiling_info_.platform_info_, nullptr); - input_values_.emplace_back(reinterpret_cast(const_cast(compute_node_info->GetNodeType())), - nullptr); + input_values_.emplace_back(std::make_pair( + reinterpret_cast(const_cast(tiling_parse_info_.compiled_json_)), Chain::Deleter(nullptr))); + input_values_.emplace_back(std::make_pair(tiling_info_.platform_info_, nullptr)); + input_values_.emplace_back( + std::make_pair(reinterpret_cast(const_cast(compute_node_info->GetNodeType())), nullptr)); GE_ASSERT_SUCCESS(BuildCtx(*holder), "BuildCtx failed."); return holder; } }; -static_assert(sizeof(OpTilingParseContextBuilderImpl) == sizeof(ContextBuilderImpl), - "OpTilingParseContextBuilderImpl size error"); -OpTilingParseContextBuilder::OpTilingParseContextBuilder() : OpContextBuilderBase() { - impl_ = ge::ComGraphMakeUnique(); -} - -OpTilingParseContextBuilder::~OpTilingParseContextBuilder() = default; +OpTilingParseContextBuilder::OpTilingParseContextBuilder() + : impl_(ge::ComGraphMakeUnique()) {} +OpTilingParseContextBuilder::~OpTilingParseContextBuilder() {}; OpTilingParseContextBuilder &OpTilingParseContextBuilder::CompiledJson(const ge::char_t *compiled_json) { GE_CHECK_NOTNULL_EXEC(impl_, return *this); impl_->SetCompiledJson(compiled_json); return *this; } -OpTilingParseContextBuilder &OpTilingParseContextBuilder::CompiledInfo(const void *compile_info) { +OpTilingParseContextBuilder &OpTilingParseContextBuilder::CompileInfo(const void *compile_info) { GE_CHECK_NOTNULL_EXEC(impl_, return *this); - impl_->SetCompiledInfo(compile_info); + impl_->SetCompileInfo(compile_info); return *this; } OpTilingParseContextBuilder &OpTilingParseContextBuilder::PlatformInfo(const void *platform_info) { @@ -72,28 +73,10 @@ ContextHolder OpTilingParseContextBuilder::Build() { return ContextHolder(std::move(holder_void)); } -OpTilingParseContextBuilder &OpTilingParseContextBuilder::InputTensorDesc(size_t index, ge::DataType dtype, - ge::Format origin_format, - ge::Format storage_format, - const gert::ExpandDimsType &expand_dims_type) { - GE_CHECK_NOTNULL_EXEC(impl_, return *this); - MutableInputDataType(index) = dtype; - MutableInputOriginalFormat(index) = origin_format; - MutableInputStorageFormat(index) = storage_format; - MutableInputExpandDimsType(index) = expand_dims_type; - return *this; -} - -OpTilingParseContextBuilder &OpTilingParseContextBuilder::OutputTensorDesc(size_t index, ge::DataType dtype, - ge::Format origin_format, - ge::Format storage_format, - const gert::ExpandDimsType &expand_dims_type) { - GE_CHECK_NOTNULL_EXEC(impl_, return *this); - MutableOutputDataType(index) = dtype; - MutableOutputOriginalFormat(index) = origin_format; - MutableOutputStorageFormat(index) = storage_format; - MutableOutputExpandDimsType(index) = expand_dims_type; - return *this; +OpInfo &OpTilingParseContextBuilder::MutableOpInfo() { + static OpInfo null_op_info; + GE_CHECK_NOTNULL_EXEC(impl_, return null_op_info); + return impl_->GetOpInfo(); } } // namespace gert \ No newline at end of file diff --git a/inc/base/attr/attrs_to_buffer.h b/inc/base/attr/attrs_to_buffer.h index 73822dd3ae..357ad7e490 100644 --- a/inc/base/attr/attrs_to_buffer.h +++ b/inc/base/attr/attrs_to_buffer.h @@ -263,11 +263,11 @@ std::unique_ptr CreateAttrBuffer(const std::vector CreateAttrBufferWithAttrs(const std::vector &attrs, size_t &size) { +std::unique_ptr CreateAttrBufferWithAttrs(const std::vector> &attrs, + size_t &size) { std::vector> runtime_attrs; for (auto &attr : attrs) { - AppendAttr(attr, runtime_attrs); + AppendAttr(attr.second, runtime_attrs); } return CreateAttrBuffer(runtime_attrs, size); } diff --git a/inc/external/base/context_builder/context_holder.h b/inc/external/base/context_builder/context_holder.h index 5dd5503f93..a6604ad17c 100644 --- a/inc/external/base/context_builder/context_holder.h +++ b/inc/external/base/context_builder/context_holder.h @@ -17,9 +17,9 @@ class ContextHolderVoid { public: ContextHolderVoid(); ~ContextHolderVoid(); - ContextHolderVoid(ContextHolderVoid&& other) noexcept; - ContextHolderVoid& operator=(ContextHolderVoid&& other) noexcept; - void *GetContext() const; + ContextHolderVoid(ContextHolderVoid&& other); + ContextHolderVoid& operator=(ContextHolderVoid&& other); + void *GetContext(); private: friend class ContextHolderBuilder; std::unique_ptr ctx_holder_impl_; @@ -31,7 +31,7 @@ class ContextHolderVoid { template class ContextHolder { public: - ContextHolder() = default; + ContextHolder() {} explicit ContextHolder(ContextHolderVoid &&holder_void) : holder_void_(std::move(holder_void)) {} /** * @brief 按指定类型获取ctx指针 diff --git a/inc/external/base/context_builder/op_context_builder_base.h b/inc/external/base/context_builder/op_context_builder_base.h deleted file mode 100644 index 4648ca9da0..0000000000 --- a/inc/external/base/context_builder/op_context_builder_base.h +++ /dev/null @@ -1,104 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_OP_CONTEXT_BUILDER_BASE_H_ -#define METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_OP_CONTEXT_BUILDER_BASE_H_ -#include -#include "graph/types.h" -#include "exe_graph/runtime/storage_shape.h" -#include "exe_graph/runtime/shape.h" -#include "graph/ascend_string.h" -#include "exe_graph/runtime//expand_dims_type.h" -namespace gert { -class ContextBuilderImpl; -/** - * @brief OpContextBuilderBase基类,用于构造Op 子类context中算子信息,包括算子类型、名称、输入输出原型个数、输入输出实例个数、属性等信息。 - * 注意:不可单独构造OpContextBuilderBase基类对象,只能通过子类构造 - * @param T 子类类型,用于返回子类对象的引用,用于支持子类链式调用 -*/ -template -class OpContextBuilderBase { - public: - /** - * @brief 设置OpType,用作构造各子类context的基础ComputeNodeInfo信息 - * @param op_type Op的类型 - * @return 返回子类对象T类型的引用,用于支持子类链式调用 - */ - T &OpType(const ge::AscendString &op_type); - /** - * @brief 设置OpName,用作构造各子类context的基础ComputeNodeInfo信息 - * @param op_name Op的名称 - * @return 返回子类对象T类型的引用,用于支持子类链式调用 - */ - T &OpName(const ge::AscendString &op_name); - /** - * @brief 设置Op输入输出IR原型个数,用作构造各子类context的基础ComputeNodeInfo信息, 默认每个IR原型输入输出的实例个数为1 - * @param input_ir_num 输入IR原型个数 - * @param output_ir_num 输出IR原型个数 - * @attention 此接口与IOInstanceNum接口互斥。仅需调用2种接口的一种即可。 - * @return 返回子类对象T类型的引用,用于支持子类链式调用 - */ - T &IONum(size_t input_ir_num, size_t output_ir_num); - /** - * @brief 当输入IR原型实例个数不为1时(一般是可选输入或动态输入场景),需要设置Op每个输入IR原型的实例个数, - * 用作构造各子类context的基础ComputeNodeInfo信息 - * @note 当算子存在dynamic input类型输入时,对应input的instance_num需设置为大于1的值 - * @param input_instance_num 每个IR原型输入的实例个数 - * @param output_instance_num 每个IR原型输出的实例个数 - * @attention 此接口与IONum接口互斥。仅需调用2种接口的一种即可。 - * @return 返回子类对象T类型的引用,用于支持子类链式调用 - */ - T &IOInstanceNum(const std::vector &input_instance_num, const std::vector &output_instance_num); - - /** - * @brief 往后追加Op IR原型的属性信息,下标从0开始,用作构造各子类context的基础ExtendedInfo里通过GetAttr接口获取到的的RuntimeAttr属性信息 - * @note 请注意,往后追加的属性,获取到的属性是一个有序列表,属性构造的顺序与通过Context的基类接口GetAttr获取到的RuntimeAttrs中属性的顺序一致. - * 例如:context_builder.AppendAttr(bool attr0).AppendAttr(int64_t attr1).AppendAttr(vector attr2),则 - * ctx->GetAttrs()->GetBool(0) -> attr0, - * ctx->GetAttrs()->GetInt(1) -> attr1, - * ctx->GetAttrs()->GetListInt(2) -> attr2 - * @param attr 属性值,当前仅支持以下确定的几种类型:bool、int64_t、float、AscendString、std::vector、 - * std::vector、std::vector、std::vector、std::vector> - * @return 返回子类对象T类型的引用,用于支持子类链式调用 - */ - T &AppendAttr(bool attr); - T &AppendAttr(int64_t attr); - T &AppendAttr(float attr); - T &AppendAttr(const ge::AscendString &attr); - T &AppendAttr(const std::vector &attr); - T &AppendAttr(const std::vector &attr); - T &AppendAttr(const std::vector &attr); - T &AppendAttr(const std::vector &attr); - T &AppendAttr(const std::vector> &attr); - - // 禁止拷贝和赋值 - OpContextBuilderBase(const OpContextBuilderBase&) = delete; - OpContextBuilderBase& operator=(const OpContextBuilderBase&) = delete; - // 禁止移动构造和移动赋值 - OpContextBuilderBase(OpContextBuilderBase&&) = delete; - OpContextBuilderBase& operator=(OpContextBuilderBase&&) = delete; - - virtual ~OpContextBuilderBase(); - -protected: - [[nodiscard]] ge::DataType &MutableInputDataType(size_t index); - [[nodiscard]] ge::Format &MutableInputOriginalFormat(size_t index); - [[nodiscard]] ge::Format &MutableInputStorageFormat(size_t index); - [[nodiscard]] gert::ExpandDimsType &MutableInputExpandDimsType(size_t index); - - [[nodiscard]] ge::DataType &MutableOutputDataType(size_t index); - [[nodiscard]] ge::Format &MutableOutputOriginalFormat(size_t index); - [[nodiscard]] ge::Format &MutableOutputStorageFormat(size_t index); - [[nodiscard]] gert::ExpandDimsType &MutableOutputExpandDimsType(size_t index); - std::unique_ptr impl_; - - OpContextBuilderBase(); -}; -} // namespace gert -#endif // METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_OP_CONTEXT_BUILDER_BASE_H_ diff --git a/inc/external/base/context_builder/op_infer_datatype_context_builder.h b/inc/external/base/context_builder/op_infer_datatype_context_builder.h index 4d3540a251..275d4a9f8d 100644 --- a/inc/external/base/context_builder/op_infer_datatype_context_builder.h +++ b/inc/external/base/context_builder/op_infer_datatype_context_builder.h @@ -9,56 +9,53 @@ #ifndef METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_INFER_DTYPE_CTX_BUILDER_H_ #define METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_INFER_DTYPE_CTX_BUILDER_H_ - +#include #include #include "graph/types.h" +#include "base/context_builder/op_info.h" #include "base/context_builder/context_holder.h" #include "exe_graph/runtime/infer_datatype_context.h" -#include "base/context_builder/op_context_builder_base.h" - namespace gert { -/** - * @brief OpInferDataTypeContextBuilder类,用于构造InferDataTypeContext. - * @note OpInferDataTypeContextBuilder类的实例化对象用于构造算子数据类型推导的执行上下文。 -*/ -class OpInferDataTypeContextBuilder : public OpContextBuilderBase { +class ContextBuilderImpl; +class OpInferDataTypeContextBuilder { public: OpInferDataTypeContextBuilder(); - ~OpInferDataTypeContextBuilder() override; - + ~OpInferDataTypeContextBuilder(); /** - * @brief 设置第index个实例输入的Tensor Description信息, - * 用于构造InferDataTypeContext的基类ExtendedKernelContext中的ComputeNodeInfo信息 - * @param index 输入的索引,对应的是Op IR原型中的的输入实例Instance索引 - * @param dtype 输入Tensor的data type - * @param origin_format 输入Tensor的原始格式 - * @param storage_format 输入Tensor的存储格式 - * @param expand_dims_type 输入Tensor的ExpandDimsType,默认值为{} - * @return OpInferDataTypeContextBuilder对象引用,用于链式调用 - */ - OpInferDataTypeContextBuilder &InputTensorDesc(size_t index, ge::DataType dtype, ge::Format origin_format, - ge::Format storage_format, - const gert::ExpandDimsType &expand_dims_type = {}); - + * @brief 获取当前Op的信息 + * @return OpInfo对象的引用 + */ + OpInfo &MutableOpInfo(); /** - * @brief 设置第index个实例输出的Tensor Description信息, - * 用于构造InferDataTypeContext的基类ExtendedKernelContext中的ComputeNodeInfo信息, - * 无需设置输出data type信息,输出data type由算子实现类根据输入DataType计算推导得到 - * @param index 输出的索引,对应的是Op IR原型中的的输出实例Instance索引 - * @param origin_format 输出Tensor的原始格式 - * @param storage_format 输出Tensor的存储格式 - * @param expand_dims_type 输出Tensor的ExpandDimsType,默认值为{} - * @return OpInferDataTypeContextBuilder对象引用,用于链式调用 + * @brief 设置context的input_values,values承载的类型为ge::DataType的值数组 + * @note 建议使用此接口,输入指针接口是用作兼容性设计,设置的所有输入数据类型,所有权归调用者管理,调用者需要保证输入指针生命周期指针长于Build产生的ContextHolder对象 + * @param inputs 输入值数组 + * @return InferDatatypeContextBuilder对象用于链式调用 */ - OpInferDataTypeContextBuilder &OutputTensorDesc(size_t index, ge::Format origin_format, ge::Format storage_format, - const gert::ExpandDimsType &expand_dims_type = {}); + OpInferDataTypeContextBuilder &Inputs(const std::vector &inputs); + /** + * @brief 设置context的input_values,values承载的类型为ge::DataType*的指针数组 + * @note 设置的所有输入数据类型,所有权归调用者管理,调用者需要保证输入指针生命周期指针长于Build产生的ContextHolder对象 + * @param inputs 输入指针数组 + * @return InferDatatypeContextBuilder对象用于链式调用 + */ + OpInferDataTypeContextBuilder &Inputs(const std::vector &inputs); + /** + * @brief 设置context的output_values,values承载的类型为ge::DataType*的指针数组 + * @note 设置的所有输入数据类型,所有权归调用者管理,调用者需要保证输入指针生命周期指针长于Build产生的ContextHolder对象 + * @param outputs 输出指针数组 + * @return InferDataTypeContextBuilder对象用于链式调用 + */ + OpInferDataTypeContextBuilder &Outputs(const std::vector &outputs); + /** * @brief 构建InferDataTypeContext对象 - * @return 返回一个ContextHolder对象,包含InferDataTypeContext指针, - * 注意返回的ContextHolder对象的生命周期需要大于等于InferDataTypeContext对象的生命周期, - * 才能保证通过InferDataTypeContext获取的所有指针的有效性 + * @return 返回一个ContextHolder对象,包含InferDataTypeContext指针 */ ContextHolder Build(); + + private: + std::unique_ptr impl_; }; } // namespace gert #endif // METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_INFER_DTYPE_CTX_BUILDER_H_ diff --git a/inc/external/base/context_builder/op_infer_shape_context_builder.h b/inc/external/base/context_builder/op_infer_shape_context_builder.h index c795cf4bc9..5f6b4ace09 100644 --- a/inc/external/base/context_builder/op_infer_shape_context_builder.h +++ b/inc/external/base/context_builder/op_infer_shape_context_builder.h @@ -9,50 +9,49 @@ #ifndef METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_INFER_SHAPE_CTX_BUILDER_H_ #define METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_INFER_SHAPE_CTX_BUILDER_H_ - +#include #include -#include "base/context_builder/op_context_builder_base.h" +#include "base/context_builder/op_info.h" #include "base/context_builder/context_holder.h" #include "exe_graph/runtime/tensor.h" #include "exe_graph/runtime/infer_shape_context.h" namespace gert { -/** - * @brief OpInferShapeContextBuilder类,用于构造InferShapeContext. - * @note OpInferShapeContextBuilder类的实例化对象用于构造算子形状推导的执行上下文。 -*/ -class OpInferShapeContextBuilder : public OpContextBuilderBase { +class ContextBuilderImpl; +class OpInferShapeContextBuilder { public: OpInferShapeContextBuilder(); - ~OpInferShapeContextBuilder() override; - + ~OpInferShapeContextBuilder(); /** - * @brief 用作构造InferShapeContext时Op输出的Tensor Description信息,用于构造 - * InferShapeContext的基类ExtendedKernelContext中的ComputeNodeInfo等信息 - * @param index 输出的索引,对应的是Op IR原型中的的输出实例Instance索引 - * @param dtype 输出Tensor的数据类型 - * @param origin_format 输出Tensor的原始格式 - * @param storage_format 输出Tensor的存储格式 - * @param expand_dims_type 输出Tensor的ExpandDimsType - * @return OpInferShapeContextBuilder对象用于链式调用 - */ - OpInferShapeContextBuilder &OutputTensorDesc(size_t index, ge::DataType dtype, ge::Format origin_format, - ge::Format storage_format, const gert::ExpandDimsType &expand_dims_type = {}); - + * @brief 获取当前Op的信息 + * @return OpInfo对象的引用 + */ + OpInfo &MutableOpInfo(); /** - * @brief 设置输入Tensor指针,用于在shape推导时,可通过该builder类构造的上下文InferShapeContext获取相应的输入tensor指针 - * @note 对于数据依赖的算子,对应数据依赖的输入Tensor中的TensorData是需要有Host地址的正确值;对于非数据依赖算子,Tensor的TensorData为空指针 - * @param inputs 输入指针数组,所有权归调用者管理,调用者需要保证输入指针生命周期指针长于Build产生的ContextHolder对象 - * @return OpInferShapeContextBuilder对象用于链式调用 + * @brief 设置context的input_values,values承载的类型为gert::Tensor*的指针数组 + * @note 对于数据依赖的算子,Tensor中的TensorData是需要有Host地址的正确值;对于非数据依赖算子,Tensor的TensorData应该为空指针 + * @note 设置的所有输入数据类型,所有权归调用者管理,调用者需要保证输入指针生命周期指针长于Build产生的ContextHolder对象 + * @param inputs 输入指针数组 + * @return InferShapeContextBuilder对象用于链式调用 */ OpInferShapeContextBuilder &InputTensors(const std::vector &inputs); + /** + * @brief 设置context的output_values,values承载的类型为gert::StorageShape*的指针数组 + * @note 设置的所有输入数据类型,所有权归调用者管理,调用者需要保证输入指针生命周期指针长于Build产生的ContextHolder对象 + * @param outputs 输出指针数组 + * @return InferShapeContextBuilder对象用于链式调用 + */ + OpInferShapeContextBuilder &OutputShapes(const std::vector &outputs); + /** * @brief 构建InferShapeContext对象 - * @return 返回一个ContextHolder对象,包含InferShapeContext指针, - * 注意返回的ContextHolder对象的生命周期需要大于等于InferShapeContext对象的生命周期,才能保证通过InferShapeContext获取的所有指针的有效性 + * @return 返回一个ContextHolder对象,包含InferShapeContext指针 */ ContextHolder Build(); + + private: + std::unique_ptr impl_; }; } // namespace gert #endif // METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_INFER_SHAPE_CTX_BUILDER_H_ diff --git a/inc/external/base/context_builder/op_infer_shape_range_context_builder.h b/inc/external/base/context_builder/op_infer_shape_range_context_builder.h index 988bac5cf7..d31667b634 100644 --- a/inc/external/base/context_builder/op_infer_shape_range_context_builder.h +++ b/inc/external/base/context_builder/op_infer_shape_range_context_builder.h @@ -9,53 +9,47 @@ #ifndef METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_INFER_SHAPERANGE_CTX_BUILDER_H_ #define METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_INFER_SHAPERANGE_CTX_BUILDER_H_ - +#include #include +#include "base/context_builder/op_info.h" #include "base/context_builder/context_holder.h" #include "exe_graph/runtime/infer_shape_range_context.h" -#include "base/context_builder/op_context_builder_base.h" namespace gert { -/** - * @brief OpInferShapeRangeContextBuilder类,用于构造InferShapeRangeContext. - * @note OpInferShapeRangeContextBuilder类的实例化对象用于构造算子shape range推导的执行上下文。 -*/ -class OpInferShapeRangeContextBuilder : public OpContextBuilderBase { +class ContextBuilderImpl; +class OpInferShapeRangeContextBuilder { public: OpInferShapeRangeContextBuilder(); - ~OpInferShapeRangeContextBuilder() override; - + ~OpInferShapeRangeContextBuilder(); /** - * @brief 用作构造InferShapeRangeContext时Op输出的Tensor Description信息,用于构造InferShapeRangeContext的基类 - * ExtendedKernelContext中的ComputeNodeInfo信息 - * @param index 输出的索引,对应的是Op IR原型中的的输出实例Instance索引 - * @param dtype 输出Tensor的数据类型 - * @param origin_format 输出Tensor的原始格式 - * @param storage_format 输出Tensor的存储格式 - * @param expand_dims_type 输出Tensor的ExpandDimsType信息 - * @return OpInferShapeRangeContextBuilder对象用于链式调用 + * @brief 获取当前Op的信息 + * @return OpInfo对象的引用 */ - OpInferShapeRangeContextBuilder &OutputTensorDesc(size_t index, ge::DataType dtype, ge::Format origin_format, - ge::Format storage_format, - const gert::ExpandDimsType &expand_dims_type = {}); + OpInfo &MutableOpInfo(); + /** + * @brief 设置context的input_values,values承载的类型为gert::Range*的指针数组 + * @note 设置的所有输入数据类型,所有权归调用者管理,调用者需要保证输入指针生命周期指针长于Build产生的ContextHolder对象 + * @param inputs 输入指针数组 + * @return InferShapeRangeContextBuilder对象用于链式调用 + */ + OpInferShapeRangeContextBuilder &InputTensors(const std::vector *> &inputs); /** - * @brief 设置输入Tensor range指针,用于在shape range推导时,可通过该builder类构造的上下文InferShapeRangeContext获取相应的输入tensor range指针 - * 即可以获得Max Tensor range和Min Tensor range - * @note 对于数据依赖的算子,对应数据依赖的输入Tensor中的TensorData是需要有Host地址的正确值;对于非数据依赖算子,Tensor的TensorData为空指针 - * @note 调用此接口时,输入的Range的Min Tensor跟Max Tensor需要保证DataType、OriginFormat、StorageFormat一致 - * @param inputs 输入指针数组,所有权归调用者管理,调用者需要保证输入指针生命周期指针长于Build产生的ContextHolder对象 - * @return OpInferShapeRangeContextBuilder对象用于链式调用 + * @brief 设置context的output_values,values承载的类型为gert::Range*的指针数组 + * @note 设置的所有输入数据类型,所有权归调用者管理,调用者需要保证输入指针生命周期指针长于Build产生的ContextHolder对象 + * @param outputs 输出指针数组 + * @return InferShapeRangeContextBuilder对象用于链式调用 */ - OpInferShapeRangeContextBuilder &InputTensorsRange(const std::vector *> &inputs); + OpInferShapeRangeContextBuilder &OutputShapes(const std::vector *> &outputs); /** * @brief 构建InferShapeRangeContext对象 - * @return 返回一个ContextHolder对象,包含InferShapeRangeContext指针, - * @note 返回的ContextHolder对象的生命周期需要大于等于InferShapeRangeContext对象的生命周期, - * 才能保证通过InferShapeRangeContext获取的所有指针的有效性 + * @return 返回一个ContextHolder对象,包含InferShapeRangeContext指针 */ ContextHolder Build(); + + private: + std::unique_ptr impl_; }; } // namespace gert #endif // METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_INFER_SHAPERANGE_CTX_BUILDER_H_ diff --git a/inc/external/base/context_builder/op_info.h b/inc/external/base/context_builder/op_info.h new file mode 100644 index 0000000000..6b991a7b6b --- /dev/null +++ b/inc/external/base/context_builder/op_info.h @@ -0,0 +1,101 @@ +/* Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_OP_INFO_H_ +#define METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_OP_INFO_H_ +#include +#include "graph/types.h" +#include "exe_graph/runtime/storage_shape.h" +#include "exe_graph/runtime/shape.h" +#include "graph/ascend_string.h" +using namespace ge; +namespace gert { +struct OpInfoImpl; +class OpInfoBuilder; +/** + * @brief OpInfo类用于描述算子信息,包括算子类型、名称、输入输出个数、输入输出Tensor信息等。 + * @note OpInfo类的实例化对象用于构造算子执行上下文 +*/ +class OpInfo { + public: + OpInfo(); + ~OpInfo(); + /** + * @brief 设置OpType,用作构造context的输入 + * @param op_type Op的类型 + * @return OpInfo对象用于链式调用 + */ + OpInfo &OpType(const AscendString &op_type); + /** + * @brief 设置OpName,用作构造context的输入 + * @param op_name Op的名称 + * @return OpInfo对象用于链式调用 + */ + OpInfo &OpName(const AscendString &op_name); + /** + * @brief 设置Op输入输出个数,用作构造context的输入, 默认每个输入输出的实例个数为1 + * @param input_num 输入个数 + * @param output_num 输出个数 + * @attention 此接口与IOInstanceNum接口互斥。仅需调用2种接口的一种即可。 + * @return OpInfo对象用于链式调用 + */ + OpInfo &IONum(size_t input_num, size_t output_num); + /** + * @brief 当输入实例个数不为1时,需要设置Op每个输入的实例个数,用作构造context的输入。 + * @note 当算子存在dynamic input类型输入时,对应input的instance_num可以设置为大于1的值 + * @param input_instance_num 每个输入的实例个数 + * @param output_instance_num 每个输出的实例个数 + * @attention 此接口与IONum接口互斥。仅需调用2种接口的一种即可。 + * @return OpInfo对象用于链式调用 + */ + OpInfo &IOInstanceNum(const std::vector &input_instance_num, const std::vector &output_instance_num); + /** + * @brief 用作构造KernelRunContext,InferShapeContext,InferDataTypeContext,InferShapeRangeContext时设置Op输入的Tensor信息 + * @param index 输入的索引 + * @param dtype 输入Tensor的数据类型 + * @param origin_format 输入Tensor的原始格式 + * @param storage_format 输入Tensor的存储格式 + * @param shape 输入Tensor的shape + * @return OpInfo对象用于链式调用 + */ + OpInfo &SetInputTd(size_t index, ge::DataType dtype, ge::Format origin_format, ge::Format storage_format, + const gert::StorageShape &shape = {}); + /** + * @brief 用作构造KernelRunContext,InferShapeContext,InferDataTypeContext,InferShapeRangeContext时设置Op输出的Tensor信息 + * @param index 输出的索引 + * @param dtype 输出Tensor的数据类型 + * @param origin_format 输出Tensor的原始格式 + * @param storage_format 输出Tensor的存储格式 + * @param shape 输出Tensor的shape + * @return OpInfo对象用于链式调用 + */ + OpInfo &SetOutputTd(size_t index, ge::DataType dtype, ge::Format origin_format, ge::Format storage_format, + const gert::StorageShape &shape = {}); + /** + * @brief 设置Op的属性信息,用作构造context的输入 + * @param attr_name 属性名称 + * @param attr 属性值 + * @return OpInfo对象用于链式调用 + */ + OpInfo &Attr(const AscendString &attr_name, bool attr); + OpInfo &Attr(const AscendString &attr_name, int64_t attr); + OpInfo &Attr(const AscendString &attr_name, float attr); + OpInfo &Attr(const AscendString &attr_name, const AscendString &attr); + OpInfo &Attr(const AscendString &attr_name, const std::vector &attr); + OpInfo &Attr(const AscendString &attr_name, const std::vector &attr); + OpInfo &Attr(const AscendString &attr_name, const std::vector &attr); + OpInfo &Attr(const AscendString &attr_name, const std::vector &attr); + OpInfo &Attr(const AscendString &attr_name, const std::vector> &attr); + + private: + std::shared_ptr impl_; + friend class OpInfoHelper; +}; +} // namespace gert +#endif // METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_OP_INFO_H_ \ No newline at end of file diff --git a/inc/external/base/context_builder/op_kernel_run_context_builder.h b/inc/external/base/context_builder/op_kernel_run_context_builder.h index 7244e315bc..088fe1d712 100644 --- a/inc/external/base/context_builder/op_kernel_run_context_builder.h +++ b/inc/external/base/context_builder/op_kernel_run_context_builder.h @@ -11,51 +11,26 @@ #define METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_KERNEL_CONTEXT_BUILDER_H_ #include #include +#include "base/context_builder/op_info.h" #include "base/context_builder/context_holder.h" #include "exe_graph/runtime/kernel_run_context.h" -#include "base/context_builder/op_context_builder_base.h" -#include "exe_graph//runtime/expand_dims_type.h" namespace gert { -/** - * @brief OpKernelRunContextBuilder类,用于构造KernelContext. - * @note OpKernelRunContextBuilder类的实例化对象用于构造算子host上执行相关交付件的上下文。 - */ -class OpKernelRunContextBuilder : public OpContextBuilderBase { +class ContextBuilderImpl; +class OpKernelRunContextBuilder { public: OpKernelRunContextBuilder(); - ~OpKernelRunContextBuilder() override; + ~OpKernelRunContextBuilder(); /** - * @brief 设置第index个实例输入的Tensor Description信息,用于构造 - * KernelRunContext的基类ExtendedKernelContext中的ComputeNodeInfo信息 - * @param index 输入的索引,对应的是Op IR原型中的的输入实例Instance索引 - * @param dtype 输入Tensor的data type - * @param origin_format 输入Tensor的原始格式 - * @param storage_format 输入Tensor的存储格式 - * @param expand_dims_type 输入Tensor的ExpandDimsType,默认值为{} - * @return OpKernelRunContextBuilder对象引用,用于链式调用 - */ - OpKernelRunContextBuilder &InputTensorDesc(size_t index, ge::DataType dtype, ge::Format origin_format, - ge::Format storage_format, - const gert::ExpandDimsType &expand_dims_type = {}); - /** - * @brief 设置第index个实例输出的Tensor Description信息,用于构造 - * KernelRunContext的基类ExtendedKernelContext中的ComputeNodeInfo信息 - * @param index 输出的索引,对应的是Op IR原型中的的输出实例Instance索引 - * @param dtype 输出Tensor的data type - * @param origin_format 输出Tensor的原始格式 - * @param storage_format 输出Tensor的存储格式 - * @param expand_dims_type 输出Tensor的ExpandDimsType,默认值为{} - * @return OpKernelRunContextBuilder对象引用,用于链式调用 + * @brief 获取当前Op的信息 + * @return OpInfo对象的引用 */ - OpKernelRunContextBuilder &OutputTensorDesc(size_t index, ge::DataType dtype, ge::Format origin_format, - ge::Format storage_format, - const gert::ExpandDimsType &expand_dims_type = {}); - + OpInfo &MutableOpInfo(); /** * @brief 设置context的values的输入指针,values承载的类型为void*的指针数组 - * @param inputs 输入指针数组,所有权归调用者管理,调用者需要保证输入指针生命周期指针长于Build产生的ContextHolder对象 - * @return OpKernelRunContextBuilder对象引用,用于链式调用 + * @note 设置的所有输入数据类型,所有权归调用者管理,调用者需要保证输入指针生命周期指针长于Build产生的ContextHolder对象 + * @param inputs 输入指针数组 + * @return KernelRunContextBuilder对象用于链式调用 */ OpKernelRunContextBuilder &Inputs(std::vector inputs); @@ -63,17 +38,18 @@ class OpKernelRunContextBuilder : public OpContextBuilderBase outputs); /** * @brief 构建KernelRunContext对象 - * @return 返回一个ContextHolder对象,包含KernelRunContext指针, - * @note 注意返回的ContextHolder对象的生命周期需要大于等于KernelRunContext对象的生命周期, - * 才能保证通过KernelRunContext获取的所有指针的有效性 + * @return 返回一个ContextHolder对象,包含KernelRunContext指针 */ ContextHolder Build(); + + private: + std::unique_ptr impl_; }; } // namespace gert #endif // METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_KERNEL_CONTEXT_BUILDER_H_ diff --git a/inc/external/base/context_builder/op_tiling_context_builder.h b/inc/external/base/context_builder/op_tiling_context_builder.h index 2dc37d8c06..4da8ceb199 100644 --- a/inc/external/base/context_builder/op_tiling_context_builder.h +++ b/inc/external/base/context_builder/op_tiling_context_builder.h @@ -12,20 +12,25 @@ #include #include #include "base/context_builder/context_holder.h" +#include "base/context_builder/op_info.h" #include "exe_graph/runtime/tiling_context.h" -#include "base/context_builder/op_context_builder_base.h" namespace gert { +class ContextBuilderImpl; class TilingData; /** - * @brief OpTilingContextBuilder类,用于构造TilingContext. - * @note OpTilingContextBuilder类的实例化对象用于构造算子tiling的执行上下文。 -*/ -class OpTilingContextBuilder : public OpContextBuilderBase { + * @brief TilingContextBuilder类用于构建算子执行上下文,包含Op信息、编译信息、平台信息、tiling数据、workspace内存等。 + * @note TilingContextBuilder类的实例化对象用于构造算子执行上下文。 + */ +class OpTilingContextBuilder { public: OpTilingContextBuilder(); - ~OpTilingContextBuilder() override; - + ~OpTilingContextBuilder(); + /** + * @brief 获取当前Op的信息 + * @return OpInfo对象的引用 + */ + OpInfo &MutableOpInfo(); /** * @brief 设置Op的compileInfo指针 * @note 设置的所有输入数据类型,所有权归调用者管理,调用者需要保证输入指针生命周期指针长于Build产生的ContextHolder对象 @@ -44,57 +49,49 @@ class OpTilingContextBuilder : public OpContextBuilderBase &inputs); /** - * @brief 设置输出Tensor指针,用于在tiling计算时,可通过该builder类构造的上下文TilingContext获取相应的输出tensor指针 - * @param outputs 输出Tensor指针数组,所有权归调用者管理,调用者需要保证输出指针生命周期指针长于Build产生的ContextHolder对象 - * @return OpTilingContextBuilder对象用于链式调用 - */ + * @brief 设置context的output_values,values承载的类型为gert::Tensor*的指针数组 + * @note 对于数据依赖的算子,Tensor中的TensorData是需要有Host地址的正确值;对于非数据依赖算子,Tensor的TensorData应该为空指针 + * @note 设置的所有输入数据类型,所有权归调用者管理,调用者需要保证输入指针生命周期指针长于Build产生的ContextHolder对象 + * @param outputs 输出指针数组 + * @return InferShapeContextBuilder对象用于链式调用 + */ OpTilingContextBuilder &OutputTensors(const std::vector &outputs); /** * @brief 构建TilingContext对象 - * @return ContextHolder对象,包含TilingContext指针, - * @note 注意返回的ContextHolder对象的生命周期需要大于等于TilingContext对象的生命周期, - * 才能保证通过TilingContext获取的所有指针的有效性 + * @return ContextHolder对象,包含TilingContext指针 */ ContextHolder Build(); + + private: + std::unique_ptr impl_; }; } // namespace gert #endif // METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_TILING_CONTEXT_BUILDER_H_ \ No newline at end of file diff --git a/inc/external/base/context_builder/op_tiling_parse_context_builder.h b/inc/external/base/context_builder/op_tiling_parse_context_builder.h index 28f4860139..ec8ede2ad4 100644 --- a/inc/external/base/context_builder/op_tiling_parse_context_builder.h +++ b/inc/external/base/context_builder/op_tiling_parse_context_builder.h @@ -9,74 +9,57 @@ #ifndef METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_TILING_PARSE_CONTEXT_BUILDER_H_ #define METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_TILING_PARSE_CONTEXT_BUILDER_H_ +#include +#include #include "base/context_builder/context_holder.h" +#include "base/context_builder/op_info.h" #include "exe_graph/runtime/continuous_vector.h" #include "exe_graph/runtime/tiling_parse_context.h" -#include "base/context_builder/op_context_builder_base.h" namespace gert { +class ContextBuilderImpl; /** - * @brief OpTilingParseContextBuilder类,用于构造TilingParseContext. - * @note OpTilingParseContextBuilder类的实例化对象用于构造算子tiling的执行上下文。 -*/ -class OpTilingParseContextBuilder : public OpContextBuilderBase { + * @brief TilingContextBuilder类用于构建算子执行上下文,包含Op信息、编译信息、平台信息、tiling数据、workspace内存等。 + * @note TilingContextBuilder类的实例化对象用于构造算子执行上下文。 + */ +class OpTilingParseContextBuilder { public: OpTilingParseContextBuilder(); - ~OpTilingParseContextBuilder() override; + ~OpTilingParseContextBuilder(); /** - * @brief 设置第index个实例输入的Tensor Description信息, - * 用于构造TilingParseContext的基类ExtendedKernelContext中的ComputeNodeInfo等信息 - * @param index 输入的索引,对应的是Op IR原型中的的输入实例Instance索引 - * @param dtype 输入Tensor的data type - * @param origin_format 输入Tensor的原始格式 - * @param storage_format 输入Tensor的存储格式 - * @param expand_dims_type 输入Tensor的ExpandDimsType,默认值为{} - * @return OpTilingParseContextBuilder对象引用,用于链式调用 + * @brief 获取当前Op的信息 + * @return OpInfo对象的引用 */ - OpTilingParseContextBuilder &InputTensorDesc(size_t index, ge::DataType dtype, ge::Format origin_format, - ge::Format storage_format, - const gert::ExpandDimsType &expand_dims_type = {}); + OpInfo &MutableOpInfo(); /** - * @brief 设置第index个实例输出的Tensor Description信息, - * 用于构造TilingParseContext的基类ExtendedKernelContext中的ComputeNodeInfo等信息 - * @param index 输出的索引,对应的是Op IR原型中的的输出实例Instance索引 - * @param dtype 输出Tensor的data type - * @param origin_format 输出Tensor的原始格式 - * @param storage_format 输出Tensor的存储格式 - * @param expand_dims_type 输出Tensor的ExpandDimsType,默认值为{} - * @return OpTilingParseContextBuilder对象引用,用于链式调用 - */ - OpTilingParseContextBuilder &OutputTensorDesc(size_t index, ge::DataType dtype, ge::Format origin_format, - ge::Format storage_format, - const gert::ExpandDimsType &expand_dims_type = {}); - /** - * @brief 设置Op的compileJson指针,json格式文件指针, 用于构造通过TilingParseContext获取的的compiled_json字段 + * @brief 设置Op的compileJson指针 * @note 设置的所有输入数据类型,所有权归调用者管理,调用者需要保证输入指针生命周期指针长于Build产生的ContextHolder对象 - * @param compiled_json 编译信息json文件指针 - * @return OpTilingParseContextBuilder对象用于链式调用 + * @param compile_json 编译信息json文件指针 + * @return TilingContextBuilder对象用于链式调用 */ OpTilingParseContextBuilder &CompiledJson(const ge::char_t *compiled_json); /** - * @brief 设置Op的CompiledInfo指针, 用于构造TilingParseContext中的CompiledInfo字段 + * @brief 设置Op的compileInfo指针 * @note 设置的所有输入数据类型,所有权归调用者管理,调用者需要保证输入指针生命周期指针长于Build产生的ContextHolder对象 * @param compile_info 编译信息指针 - * @return OpTilingParseContextBuilder对象用于链式调用 + * @return TilingContextBuilder对象用于链式调用 */ - OpTilingParseContextBuilder &CompiledInfo(const void *compile_info); + OpTilingParseContextBuilder &CompileInfo(const void *compile_info); /** - * @brief 设置Op的PlatFormInfo指针, 用于构造TilingParseContext的PlatformInfo字段 + * @brief 设置Op的平台信息 * @note 设置的所有输入数据类型,所有权归调用者管理,调用者需要保证输入指针生命周期指针长于Build产生的ContextHolder对象 * @param platform_info 平台信息指针 - * @return OpTilingParseContextBuilder对象用于链式调用 + * @return TilingContextBuilder对象用于链式调用 */ OpTilingParseContextBuilder &PlatformInfo(const void *platform_info); /** - * @brief 构建TilingParseContext对象 - * @return ContextHolder对象,包含TilingParseContext指针, - * @note 注意返回的ContextHolder对象的生命周期需要大于等于TilingParseContext对象的生命周期, - * 才能保证通过TilingParseContext获取的所有指针的有效性 + * @brief 构建TilingContext对象 + * @return ContextHolder对象,包含TilingContext指针 */ ContextHolder Build(); + + private: + std::unique_ptr impl_; }; } // namespace gert #endif // METADEF_INC_EXTERNAL_BASE_CONTEXT_BUILDER_TILING_PARSE_CONTEXT_BUILDER_H_ \ No newline at end of file diff --git a/tests/ut/base/testcase/context_builder_unittest.cc b/tests/ut/base/testcase/context_builder_unittest.cc index 09eaeb1b82..c90910c27a 100644 --- a/tests/ut/base/testcase/context_builder_unittest.cc +++ b/tests/ut/base/testcase/context_builder_unittest.cc @@ -2,6 +2,7 @@ #include #include #include "base/context_builder/context_holder.h" +#include "base/context_builder/op_info.h" #include "base/context_builder/op_kernel_run_context_builder.h" #include "base/context_builder/op_infer_datatype_context_builder.h" #include "base/context_builder/op_infer_shape_context_builder.h" @@ -25,18 +26,17 @@ class UtestContextBuilder : public testing::Test {}; TEST_F(UtestContextBuilder, CreateKernelRunContextOK) { OpKernelRunContextBuilder ctx_builder; gert::StorageShape shape0 = {{10, 20}, {10, 20}}; - auto holder = ctx_builder.OpType("Add") - .OpName("add_1") - .IONum(2, 1) - .InputTensorDesc(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputTensorDesc(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .OutputTensorDesc(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .Inputs({&shape0, &shape0}) - .Outputs({&shape0}) - .Build(); - auto ctx = reinterpret_cast(holder.GetContext()); + ctx_builder.MutableOpInfo() + .OpType("Add") + .OpName("add_1") + .IONum(2, 1) + .SetInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) + .SetInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) + .SetOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ); + auto holder = ctx_builder.Inputs({&shape0, &shape0}).Outputs({&shape0}).Build(); + auto ctx = (KernelContext *) holder.GetContext(); EXPECT_NE(ctx, nullptr); - auto ctx_compute_node_info = static_cast(ctx->GetComputeNodeExtend()); + auto ctx_compute_node_info = (gert::ComputeNodeInfo *) (ctx->GetComputeNodeExtend()); EXPECT_NE(ctx_compute_node_info, nullptr); EXPECT_EQ(std::string(ctx_compute_node_info->GetNodeType()), std::string("Add")); EXPECT_EQ(std::string(ctx_compute_node_info->GetNodeName()), std::string("add_1")); @@ -56,15 +56,14 @@ TEST_F(UtestContextBuilder, CreateKernelRunContextOK) { TEST_F(UtestContextBuilder, CreateKernelRunContextFailed) { OpKernelRunContextBuilder ctx_builder; gert::StorageShape shape0 = {{10, 20}, {10, 20}}; - auto holder = ctx_builder.OpName("add_1") - .IONum(1, 1) - .InputTensorDesc(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputTensorDesc(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .OutputTensorDesc(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .Inputs({&shape0, &shape0}) - .Outputs({&shape0}) - .Build(); - auto ctx = reinterpret_cast(holder.GetContext()); + ctx_builder.MutableOpInfo() + .OpName("add_1") + .IONum(1, 1) + .SetInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) + .SetInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) + .SetOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ); + auto holder = ctx_builder.Inputs({&shape0, &shape0}).Outputs({&shape0}).Build(); + auto ctx = (KernelContext *) holder.GetContext(); EXPECT_EQ(ctx, nullptr); } @@ -78,15 +77,16 @@ TEST_F(UtestContextBuilder, CreateInferDataTypeContextOK) { std::vector input_dtype_ref = {&dtype0, &dtype1, &dtype2, &dtype3}; std::vector output_dtype_ref = {&dtype4}; - auto holder = ctx_builder.OpType("Concat") - .OpName("concat_1") - .IOInstanceNum({4}, {1}) - .InputTensorDesc(0, dtype0, ge::FORMAT_ND, ge::FORMAT_ND) - .InputTensorDesc(1, dtype1, ge::FORMAT_ND, ge::FORMAT_ND) - .InputTensorDesc(2, dtype2, ge::FORMAT_ND, ge::FORMAT_ND) - .InputTensorDesc(3, dtype3, ge::FORMAT_ND, ge::FORMAT_ND) - .OutputTensorDesc(0, ge::FORMAT_ND, ge::FORMAT_ND) - .Build(); + ctx_builder.MutableOpInfo() + .OpType("Concat") + .OpName("concat_1") + .IOInstanceNum({4}, {1}) + .SetInputTd(0, dtype0, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND) + .SetInputTd(1, dtype1, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND) + .SetInputTd(2, dtype2, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND) + .SetInputTd(3, dtype3, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND) + .SetOutputTd(0, ge::DT_FLOAT, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND); + auto holder = ctx_builder.Inputs(input_dtype_ref).Outputs(output_dtype_ref).Build(); auto ctx = holder.GetContext(); EXPECT_NE(ctx, nullptr); auto ctx_compute_node_info = ctx->GetComputeNodeInfo(); @@ -100,14 +100,14 @@ TEST_F(UtestContextBuilder, CreateInferDataTypeContextOK) { const CompileTimeTensorDesc *info_input_0 = ctx_compute_node_info->GetInputTdInfo(0); EXPECT_NE(info_input_0, nullptr); EXPECT_EQ(info_input_0->GetStorageFormat(), ge::FORMAT_ND); - EXPECT_EQ(info_input_0->GetOriginFormat(), ge::FORMAT_ND); + EXPECT_EQ(info_input_0->GetOriginFormat(), ge::FORMAT_FRACTAL_NZ); ge::DataType expected_datatype_0 = ge::DT_FLOAT; ge::DataType expected_datatype_1 = ge::DT_FLOAT16; EXPECT_EQ(ctx->GetInputDataType(0), expected_datatype_0); EXPECT_EQ(ctx->GetInputDataType(1), expected_datatype_1); EXPECT_EQ(ctx->GetInputDataType(2), expected_datatype_0); EXPECT_EQ(ctx->GetInputDataType(3), expected_datatype_1); - EXPECT_EQ(ctx->GetOutputDataType(0), ge::DT_MAX); + EXPECT_EQ(ctx->GetOutputDataType(0), expected_datatype_1); } TEST_F(UtestContextBuilder, CreateInferDataTypeWithTypeContextOK) { @@ -120,15 +120,16 @@ TEST_F(UtestContextBuilder, CreateInferDataTypeWithTypeContextOK) { std::vector input_dtype_ref = {dtype0, dtype1, dtype2, dtype3}; std::vector output_dtype_ref = {&dtype4}; - auto holder = ctx_builder.OpType("Concat") - .OpName("concat_1") - .IOInstanceNum({4}, {1}) - .InputTensorDesc(0, dtype0, ge::FORMAT_ND, ge::FORMAT_ND) - .InputTensorDesc(1, dtype1, ge::FORMAT_ND, ge::FORMAT_ND) - .InputTensorDesc(2, dtype2, ge::FORMAT_ND, ge::FORMAT_ND) - .InputTensorDesc(3, dtype3, ge::FORMAT_ND, ge::FORMAT_ND) - .OutputTensorDesc(0, ge::FORMAT_ND, ge::FORMAT_ND) - .Build(); + ctx_builder.MutableOpInfo() + .OpType("Concat") + .OpName("concat_1") + .IOInstanceNum({4}, {1}) + .SetInputTd(0, dtype0, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND) + .SetInputTd(1, dtype1, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND) + .SetInputTd(2, dtype2, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND) + .SetInputTd(3, dtype3, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND) + .SetOutputTd(0, ge::DT_FLOAT, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND); + auto holder = ctx_builder.Inputs(input_dtype_ref).Outputs(output_dtype_ref).Build(); auto ctx = holder.GetContext(); EXPECT_NE(ctx, nullptr); auto ctx_compute_node_info = ctx->GetComputeNodeInfo(); @@ -142,24 +143,30 @@ TEST_F(UtestContextBuilder, CreateInferDataTypeWithTypeContextOK) { const CompileTimeTensorDesc *info_input_0 = ctx_compute_node_info->GetInputTdInfo(0); EXPECT_NE(info_input_0, nullptr); EXPECT_EQ(info_input_0->GetStorageFormat(), ge::FORMAT_ND); - EXPECT_EQ(info_input_0->GetOriginFormat(), ge::FORMAT_ND); + EXPECT_EQ(info_input_0->GetOriginFormat(), ge::FORMAT_FRACTAL_NZ); ge::DataType expected_datatype_0 = ge::DT_FLOAT; ge::DataType expected_datatype_1 = ge::DT_FLOAT16; EXPECT_EQ(ctx->GetInputDataType(0), expected_datatype_0); EXPECT_EQ(ctx->GetInputDataType(1), expected_datatype_1); EXPECT_EQ(ctx->GetInputDataType(2), expected_datatype_0); EXPECT_EQ(ctx->GetInputDataType(3), expected_datatype_1); - EXPECT_EQ(ctx->GetOutputDataType(0), ge::DT_MAX); + EXPECT_EQ(ctx->GetOutputDataType(0), expected_datatype_1); } TEST_F(UtestContextBuilder, CreateInferShapeContextFailed) { OpInferShapeContextBuilder ctx_builder; gert::StorageShape shape0 = {{10, 20}, {10, 20}}; - gert::Tensor tensor; - tensor.MutableStorageShape() = shape0.GetStorageShape(); - tensor.MutableOriginShape() = shape0.GetOriginShape(); - ctx_builder.OpType("DIY").OpName("diy_1").IONum(2, 1); - auto holder = ctx_builder.InputTensors({&tensor, &tensor}).Build(); + ctx_builder.MutableOpInfo() + .OpType("DIY") + .OpName("diy_1") + .IONum(2, 1) + .SetInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) + .SetInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) + .SetInputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) + .SetOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) + .SetOutputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ); + auto holder = + ctx_builder.InputTensors({(gert::Tensor *) &shape0, (gert::Tensor *) &shape0}).OutputShapes({&shape0}).Build(); auto ctx = holder.GetContext(); EXPECT_NE(ctx, nullptr); auto ctx_compute_node_info = ctx->GetComputeNodeInfo(); @@ -178,19 +185,23 @@ TEST_F(UtestContextBuilder, CreateInferShapeContextOK) { StorageShape shape1 = {{2, 3, 4, 5}, {5, 4, 3, 2}}; StorageShape shape2 = {{3, 4, 5, 6}, {6, 5, 4, 3}}; StorageShape shape3 = {{4, 5, 6, 7}, {7, 6, 5, 4}}; - StorageFormat format{FORMAT_ND, FORMAT_FRACTAL_NZ, {}}; - gert::Tensor tensor0{shape0, format, ge::DT_FLOAT}; - gert::Tensor tensor1{shape1, format, ge::DT_FLOAT}; - gert::Tensor tensor2{shape2, format, ge::DT_FLOAT}; - gert::Tensor tensor3{shape3, format, ge::DT_FLOAT}; - - std::vector input_dtype_ref = {&tensor0, &tensor1, &tensor2, &tensor3}; - auto holder = ctx_builder.OpType("DIY") - .OpName("diy_1") - .IOInstanceNum({1, 1, 1, 1}, {1}) - .OutputTensorDesc(0, ge::DT_FLOAT, ge::FORMAT_ND, ge::FORMAT_NCHW) - .InputTensors(input_dtype_ref) - .Build(); + StorageShape outshape0 = {}; + StorageShape outshape1 = {}; + + std::vector input_dtype_ref = {(gert::Tensor *) &shape0, (gert::Tensor *) &shape1, + (gert::Tensor *) &shape2, (gert::Tensor *) &shape3}; + std::vector output_dtype_ref = {&outshape0, &outshape1}; + ctx_builder.MutableOpInfo() + .OpType("DIY") + .OpName("diy_1") + .IOInstanceNum({1, 1, 1, 1}, {2}) + .SetInputTd(0, ge::DT_FLOAT, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND) + .SetInputTd(1, ge::DT_FLOAT, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND) + .SetInputTd(2, ge::DT_FLOAT, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND) + .SetInputTd(3, ge::DT_FLOAT, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND) + .SetOutputTd(0, ge::DT_FLOAT, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND) + .SetOutputTd(1, ge::DT_FLOAT, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND); + auto holder = ctx_builder.InputTensors(input_dtype_ref).OutputShapes(output_dtype_ref).Build(); auto ctx = holder.GetContext(); EXPECT_NE(ctx, nullptr); auto ctx_compute_node_info = ctx->GetComputeNodeInfo(); @@ -200,11 +211,11 @@ TEST_F(UtestContextBuilder, CreateInferShapeContextOK) { EXPECT_EQ(ctx_compute_node_info->GetIrInputsNum(), 4); EXPECT_EQ(ctx_compute_node_info->GetIrOutputsNum(), 1); EXPECT_EQ(ctx_compute_node_info->GetInputsNum(), 4); - EXPECT_EQ(ctx_compute_node_info->GetOutputsNum(), 1); + EXPECT_EQ(ctx_compute_node_info->GetOutputsNum(), 2); const CompileTimeTensorDesc *info_input_0 = ctx_compute_node_info->GetInputTdInfo(0); EXPECT_NE(info_input_0, nullptr); - EXPECT_EQ(info_input_0->GetStorageFormat(), ge::FORMAT_FRACTAL_NZ); - EXPECT_EQ(info_input_0->GetOriginFormat(), ge::FORMAT_ND); + EXPECT_EQ(info_input_0->GetStorageFormat(), ge::FORMAT_ND); + EXPECT_EQ(info_input_0->GetOriginFormat(), ge::FORMAT_FRACTAL_NZ); EXPECT_NE(ctx->GetInputShape(0), nullptr); EXPECT_EQ(*(ctx->GetInputShape(0)), shape0.GetOriginShape()); EXPECT_NE(ctx->GetInputShape(1), nullptr); @@ -214,66 +225,35 @@ TEST_F(UtestContextBuilder, CreateInferShapeContextOK) { EXPECT_NE(ctx->GetInputShape(3), nullptr); EXPECT_EQ(*(ctx->GetInputShape(3)), shape3.GetOriginShape()); EXPECT_NE(ctx->GetOutputShape(0), nullptr); - EXPECT_EQ(ctx->GetOutputShape(0)->GetDimNum(), 0); - EXPECT_EQ(ctx->GetComputeNodeInfo()->GetOutputTdInfo(0)->GetDataType(), DT_FLOAT); - EXPECT_EQ(ctx->GetComputeNodeInfo()->GetOutputTdInfo(0)->GetOriginFormat(), FORMAT_ND); - EXPECT_EQ(ctx->GetComputeNodeInfo()->GetOutputTdInfo(0)->GetStorageFormat(), FORMAT_NCHW); -} - -TEST_F(UtestContextBuilder, CreateInferShapeContextOutOfRangeFailed) { - OpInferShapeContextBuilder ctx_builder; - StorageShape shape0 = {{1, 2, 3, 4}, {4, 3, 2, 1}}; - StorageShape shape1 = {{2, 3, 4, 5}, {5, 4, 3, 2}}; - StorageShape shape2 = {{3, 4, 5, 6}, {6, 5, 4, 3}}; - StorageShape shape3 = {{4, 5, 6, 7}, {7, 6, 5, 4}}; - StorageFormat format{FORMAT_ND, FORMAT_FRACTAL_NZ, {}}; - gert::Tensor tensor0{shape0, format, ge::DT_FLOAT}; - gert::Tensor tensor1{shape1, format, ge::DT_FLOAT}; - gert::Tensor tensor2{shape2, format, ge::DT_FLOAT}; - gert::Tensor tensor3{shape3, format, ge::DT_FLOAT}; - - std::vector input_dtype_ref = {&tensor0, &tensor1, &tensor2, &tensor3}; - auto holder = ctx_builder.OpType("DIY") - .OpName("diy_1") - .IOInstanceNum({1, 1, 1, 1}, {1}) - .OutputTensorDesc(1, ge::DT_FLOAT, ge::FORMAT_ND, ge::FORMAT_ND) - .InputTensors(input_dtype_ref) - .Build(); - auto ctx = holder.GetContext(); - EXPECT_NE(ctx, nullptr); - EXPECT_NE(ctx->GetOutputShape(0), nullptr); - EXPECT_EQ(ctx->GetOutputShape(0)->GetDimNum(), 0); - EXPECT_EQ(ctx->GetOutputShape(1), nullptr); + EXPECT_EQ(*(ctx->GetOutputShape(0)), outshape0.GetOriginShape()); + EXPECT_NE(ctx->GetOutputShape(1), nullptr); + EXPECT_EQ(*(ctx->GetOutputShape(1)), outshape1.GetOriginShape()); } TEST_F(UtestContextBuilder, CreateInferShapeRangeContextOK) { OpInferShapeRangeContextBuilder ctx_builder; - gert::StorageShape xShapeMin{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}; - gert::StorageShape xShapeMax{{10, 10, 10, 10, 20}, {10, 10, 10, 10, 20}}; - gert::StorageShape wShapeMin{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}; - gert::StorageShape wShapeMax{{10, 10, 10, 10, 20}, {10, 10, 10, 10, 20}}; - gert::Shape yShapeMinNull{1, 1, 1, 1, 1}; - gert::Shape yShapeMaxNull{10, 10, 10, 10, 20}; - gert::StorageShape yShapeMin{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}; - gert::StorageShape yShapeMax{{10, 10, 10, 10, 20}, {10, 10, 10, 10, 20}}; - StorageFormat format{FORMAT_NCDHW, FORMAT_RESERVED, {}}; - - gert::Tensor xTensorMin{xShapeMin, format, ge::DT_INT8}; - gert::Tensor xTensorMax{xShapeMax, format, ge::DT_INT8}; - gert::Range xShapeRange(&xTensorMin, &xTensorMax); - - gert::Tensor wTensorMin{wShapeMin, format, ge::DT_INT8}; - gert::Tensor wTensorMax{wShapeMax, format, ge::DT_INT8}; - - gert::Range wShapeRange(&wTensorMin, &wTensorMax); + gert::Shape xShapeMin{1, 1, 1, 1, 1}; + gert::Shape xShapeMax{10, 10, 10, 10, 20}; + gert::Shape wShapeMin{1, 1, 1, 1, 1}; + gert::Shape wShapeMax{10, 10, 10, 10, 20}; + gert::Shape yShapeMinNull{}; + gert::Shape yShapeMaxNull{}; + gert::Shape yShapeMin{1, 1, 1, 1, 1}; + gert::Shape yShapeMax{10, 10, 10, 10, 20}; + + gert::Range xShapeRange((gert::Tensor *)&xShapeMin, (gert::Tensor *)&xShapeMax); + gert::Range wShapeRange((gert::Tensor *)&wShapeMin, (gert::Tensor *)&wShapeMax); gert::Range yShapeRange(&yShapeMinNull, &yShapeMaxNull); - auto holder = ctx_builder.IONum(2, 1) - .OutputTensorDesc(0, ge::DT_FLOAT16, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED) - .OpType("DIY") - .OpName("diy_1") - .InputTensorsRange({&xShapeRange, &wShapeRange}) - .Build(); + ctx_builder.MutableOpInfo() + .IONum(2, 1) + .OpType("DIY") + .OpName("diy_1") + .SetInputTd(0, ge::DT_INT8, ge::Format::FORMAT_NCDHW, ge::Format::FORMAT_RESERVED) + .SetInputTd(1, ge::DT_INT8, ge::Format::FORMAT_NCDHW, ge::Format::FORMAT_RESERVED) + .SetOutputTd(0, ge::DT_FLOAT16, ge::Format::FORMAT_NCDHW, ge::Format::FORMAT_RESERVED); + + auto holder = ctx_builder.InputTensors({&xShapeRange, &wShapeRange}).OutputShapes({&yShapeRange}).Build(); auto ctx = holder.GetContext(); EXPECT_NE(ctx, nullptr); @@ -293,62 +273,22 @@ TEST_F(UtestContextBuilder, CreateInferShapeRangeContextOK) { EXPECT_EQ(info_output_0->GetOriginFormat(), ge::FORMAT_NCDHW); EXPECT_EQ(info_output_0->GetStorageFormat(), ge::FORMAT_RESERVED); EXPECT_NE(ctx->GetInputShapeRange(0), nullptr); - EXPECT_EQ(*(ctx->GetInputShapeRange(0)->GetMin()), xShapeMin.GetOriginShape()); - EXPECT_EQ(*(ctx->GetInputShapeRange(0)->GetMax()), xShapeMax.GetOriginShape()); + EXPECT_EQ(*(ctx->GetInputShapeRange(0)->GetMin()), xShapeMin); + EXPECT_EQ(*(ctx->GetInputShapeRange(0)->GetMax()), xShapeMax); EXPECT_NE(ctx->GetInputShapeRange(1), nullptr); - EXPECT_EQ(*(ctx->GetInputShapeRange(1)->GetMin()), wShapeMin.GetOriginShape()); - EXPECT_EQ(*(ctx->GetInputShapeRange(1)->GetMax()), wShapeMax.GetOriginShape()); + EXPECT_EQ(*(ctx->GetInputShapeRange(1)->GetMin()), wShapeMin); + EXPECT_EQ(*(ctx->GetInputShapeRange(1)->GetMax()), wShapeMax); EXPECT_NE(ctx->GetOutputShapeRange(0), nullptr); - EXPECT_NE(ctx->GetOutputShapeRange(0)->GetMin(), nullptr); - EXPECT_NE(ctx->GetOutputShapeRange(0)->GetMax(), nullptr); - EXPECT_EQ(ctx->GetOutputShapeRange(0)->GetMin()->GetDimNum(), 0); - EXPECT_EQ(ctx->GetOutputShapeRange(0)->GetMax()->GetDimNum(), 0); + EXPECT_EQ(*(ctx->GetOutputShapeRange(0)->GetMin()), yShapeMinNull); + EXPECT_EQ(*(ctx->GetOutputShapeRange(0)->GetMax()), yShapeMinNull); ctx->GetOutputShapeRange(0)->GetMin()->SetDimNum(5); ctx->GetOutputShapeRange(0)->GetMax()->SetDimNum(5); for (size_t i = 0; i < 5; i++) { - ctx->GetOutputShapeRange(0)->GetMin()->SetDim(i, yShapeMin.GetOriginShape()[i]); - ctx->GetOutputShapeRange(0)->GetMax()->SetDim(i, yShapeMax.GetOriginShape()[i]); + ctx->GetOutputShapeRange(0)->GetMin()->SetDim(i, yShapeMin[i]); + ctx->GetOutputShapeRange(0)->GetMax()->SetDim(i, yShapeMax[i]); } - EXPECT_EQ(*(ctx->GetOutputShapeRange(0)->GetMin()), yShapeMin.GetOriginShape()); - EXPECT_EQ(*(ctx->GetOutputShapeRange(0)->GetMax()), yShapeMax.GetOriginShape()); -} - -TEST_F(UtestContextBuilder, CreateInferShapeRangeContextFailed) { - OpInferShapeRangeContextBuilder ctx_builder; - gert::StorageShape xShapeMin{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}; - gert::StorageShape xShapeMax{{10, 10, 10, 10, 20}, {10, 10, 10, 10, 20}}; - gert::StorageShape wShapeMin{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}; - gert::StorageShape wShapeMax{{10, 10, 10, 10, 20}, {10, 10, 10, 10, 20}}; - gert::Shape yShapeMinNull{1, 1, 1, 1, 1}; - gert::Shape yShapeMaxNull{10, 10, 10, 10, 20}; - gert::StorageShape yShapeMin{{1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}; - gert::StorageShape yShapeMax{{10, 10, 10, 10, 20}, {10, 10, 10, 10, 20}}; - StorageFormat format{FORMAT_NCDHW, FORMAT_RESERVED, {}}; - - gert::Tensor xTensorMin{xShapeMin, format, ge::DT_FLOAT}; - gert::Tensor xTensorMax{xShapeMax, format, ge::DT_INT8}; - gert::Range xShapeRange(&xTensorMin, &xTensorMax); - - gert::Tensor wTensorMin{wShapeMin, format, ge::DT_INT8}; - gert::Tensor wTensorMax{wShapeMax, format, ge::DT_INT8}; - - gert::Range wShapeRange(&wTensorMin, &wTensorMax); - gert::Range yShapeRange(&yShapeMinNull, &yShapeMaxNull); - - auto holder = ctx_builder.IONum(2, 1) - .OutputTensorDesc(0, ge::DT_FLOAT16, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED) - .OpType("DIY") - .OpName("diy_1") - .InputTensorsRange({&xShapeRange, &wShapeRange}) - .Build(); - - auto ctx = holder.GetContext(); - EXPECT_NE(ctx, nullptr); - const CompileTimeTensorDesc *info_input_0 = ctx->GetInputDesc(0); - EXPECT_NE(info_input_0, nullptr); - EXPECT_EQ(info_input_0->GetDataType(), ge::DT_MAX); - EXPECT_EQ(info_input_0->GetOriginFormat(), ge::FORMAT_MAX); - EXPECT_EQ(info_input_0->GetStorageFormat(), ge::FORMAT_MAX); + EXPECT_EQ(*(ctx->GetOutputShapeRange(0)->GetMin()), yShapeMin); + EXPECT_EQ(*(ctx->GetOutputShapeRange(0)->GetMax()), yShapeMax); } TEST_F(UtestContextBuilder, CreateTilingContextOK) { @@ -368,36 +308,39 @@ TEST_F(UtestContextBuilder, CreateTilingContextOK) { gert::StorageShape result({10, 10, 10, 10, 20}, {10, 10, 10, 10, 20}); uint8_t data_x[1] = {9}; - gert::Tensor x_tensor(x, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()}, TensorPlacement::kOnHost, - ge::DT_FLOAT, (void *) data_x); - gert::Tensor resultIn_tensor(resultIn, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()}, - TensorPlacement::kOnHost, ge::DT_FLOAT, nullptr); - gert::Tensor gammax_tensor(gamma, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()}, TensorPlacement::kOnHost, - ge::DT_FLOAT, nullptr); - gert::Tensor beta_tensor(beta, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()}, TensorPlacement::kOnHost, - ge::DT_FLOAT, nullptr); - gert::Tensor result_tensor(result, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()}, - TensorPlacement::kOnHost, ge::DT_FLOAT, nullptr); + gert::Tensor x_tensor(x, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()}, TensorPlacement::kOnHost, ge::DT_FLOAT, (void*)data_x); + gert::Tensor resultIn_tensor(resultIn, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()}, TensorPlacement::kOnHost, ge::DT_FLOAT, nullptr); + gert::Tensor gammax_tensor(gamma, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()}, TensorPlacement::kOnHost, ge::DT_FLOAT, nullptr); + gert::Tensor beta_tensor(beta, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()}, TensorPlacement::kOnHost, ge::DT_FLOAT, nullptr); + gert::Tensor result_tensor(result, {ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, ExpandDimsType()}, TensorPlacement::kOnHost, ge::DT_FLOAT, nullptr); auto tmp_tiling_data = gert::TilingData::CreateCap(100); uint8_t tmp_compile_info[] = {1, 2, 3, 4, 5, 6, 7}; uint8_t tmp_platform_info[] = {1, 2, 3, 4, 5, 6, 7}; - int32_t deterministic = 1; + int32_t deterministic = 10; OpTilingContextBuilder ctx_builder; - auto holder = ctx_builder.OpName("tmp") - .OpType("DIY") - .IONum(4, 1) - .AppendAttr(int64_t(1)) - .AppendAttr(bool(true)) - .AppendAttr(float(0.3)) - .AppendAttr(AscendString("my_info")) - .AppendAttr(std::vector({true, false, true})) - .AppendAttr(std::vector({1, 2, 3})) - .AppendAttr(std::vector({0.1, 0.2, 0.3})) - .AppendAttr(std::vector({"123", "234"})) - .AppendAttr(std::vector>({{1, 2, 3}, {4, 5, 6}})) - .TilingData(reinterpret_cast(tmp_tiling_data.get())) + + ctx_builder.MutableOpInfo() + .OpName("tmp") + .OpType("DIY") + .IONum(4, 1) + .SetInputTd(0, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, x) + .SetInputTd(1, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, resultIn) + .SetInputTd(2, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, gamma) + .SetInputTd(3, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, beta) + .SetOutputTd(0, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, result) + .Attr(AscendString("axis"), int64_t(1)) + .Attr(AscendString("transpose"), bool(true)) + .Attr(AscendString("eps"), float(0.3)) + .Attr(AscendString("info"), AscendString("my_info")) + .Attr(AscendString("bool_vec"), std::vector({true, false, true})) + .Attr(AscendString("int_vec"), std::vector({1, 2, 3})) + .Attr(AscendString("float_vec"), std::vector({0.1, 0.2, 0.3})) + .Attr(AscendString("str_vec"), std::vector({"123", "234"})) + .Attr(AscendString("vec_vec_int"), std::vector>({{1, 2, 3}, {4, 5, 6}})); + + auto holder = ctx_builder.TilingData((gert::TilingData *) tmp_tiling_data.get()) .Workspace(ws_ptr) .CompileInfo(tmp_compile_info) .Deterministic(deterministic) @@ -420,10 +363,10 @@ TEST_F(UtestContextBuilder, CreateTilingContextOK) { EXPECT_EQ(ctx->GetInputTensor(0)->GetSize(), x_tensor.GetSize()); EXPECT_EQ(ctx->GetOutputShape(0)->GetOriginShape(), resultShape); EXPECT_EQ(ctx->GetOutputShape(0)->GetStorageShape(), resultShape); - EXPECT_EQ(static_cast(ctx->GetWorkspaceSizes(4096)), static_cast(ws_ptr->GetData())); - EXPECT_EQ(static_cast(ctx->GetPlatformInfo()), static_cast(tmp_platform_info)); + EXPECT_EQ((void *) (ctx->GetWorkspaceSizes(4096)), (void *) ws_ptr->GetData()); + EXPECT_EQ((void *) ctx->GetPlatformInfo(), (void *) tmp_platform_info); EXPECT_EQ(ctx->GetDeterministic(), deterministic); - EXPECT_EQ(static_cast(ctx->GetRawTilingData()), static_cast(tmp_tiling_data.get())); + EXPECT_EQ((void *) ctx->GetRawTilingData(), (void *) tmp_tiling_data.get()); EXPECT_EQ(*(ctx->GetAttrs()->GetInt(0)), 1); EXPECT_EQ(*(ctx->GetAttrs()->GetBool(1)), true); EXPECT_FLOAT_EQ(*(ctx->GetAttrs()->GetFloat(2)), 0.3); @@ -448,59 +391,19 @@ TEST_F(UtestContextBuilder, CreateTilingContextOK) { EXPECT_EQ(((int64_t *) (int_vec_vec->Get(1)->GetData()))[2], 6); } -TEST_F(UtestContextBuilder, CreateTilingContextTilingDataSizeOK) { - auto workspace_size_holer = gert::ContinuousVector::Create(4096); - auto ws_ptr = reinterpret_cast(workspace_size_holer.get()); - uint8_t tmp_compile_info[] = {1, 2, 3, 4, 5, 6, 7}; - uint8_t tmp_platform_info[] = {1, 2, 3, 4, 5, 6, 7}; - - OpTilingContextBuilder ctx_builder; - auto holder = ctx_builder.OpName("tmp") - .OpType("DIY") - .IONum(4, 1) - .CompileInfo(tmp_compile_info) - .PlatformInfo(tmp_platform_info) - .TilingDataSize(100) - .Build(); +TEST_F(UtestContextBuilder, CreateTilingParseContextOK) { - auto ctx = holder.GetContext(); - EXPECT_NE(ctx, nullptr); - EXPECT_NE(ctx->GetRawTilingData(), nullptr); - EXPECT_EQ(ctx->GetRawTilingData()->GetCapacity(), 100); - - auto tmp_tiling_data_120 = gert::TilingData::CreateCap(120); - OpTilingContextBuilder ctx_builder2; - auto holder2 = ctx_builder2.OpName("tmp") - .OpType("DIY") - .IONum(1, 1) - .CompileInfo(tmp_compile_info) - .PlatformInfo(tmp_platform_info) - .TilingDataSize(100) - .TilingData(reinterpret_cast(tmp_tiling_data_120.get())) - .Workspace(ws_ptr) - .Build(); - ctx = holder2.GetContext(); - EXPECT_NE(ctx, nullptr); - EXPECT_NE(ctx->GetRawTilingData(), nullptr); - EXPECT_EQ(ctx->GetRawTilingData(), reinterpret_cast(tmp_tiling_data_120.get())); - EXPECT_EQ(ctx->GetRawTilingData()->GetCapacity(), 120); - - holder2 = ctx_builder2.OpName("tmp") - .OpType("DIY") - .IONum(1, 1) - .CompileInfo(tmp_compile_info) - .PlatformInfo(tmp_platform_info) - .TilingData(reinterpret_cast(tmp_tiling_data_120.get())) - .TilingDataSize(100) - .Workspace(ws_ptr) - .Build(); - ctx = holder.GetContext(); - EXPECT_NE(ctx, nullptr); - EXPECT_NE(ctx->GetRawTilingData(), nullptr); - EXPECT_EQ(ctx->GetRawTilingData()->GetCapacity(), 100); -} + gert::Shape shape_0{1, 1, 1, 1, 1}; + gert::Shape shape_1{10, 10, 10, 10, 20}; + gert::Shape shape_2{1, 1, 1, 1, 1}; + gert::Shape shape_3{10, 10, 10, 10, 20}; + gert::Shape resultShape{10, 10, 10, 10, 20}; -TEST_F(UtestContextBuilder, CreateTilingParseContextOK) { + gert::StorageShape x({1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}); + gert::StorageShape resultIn({10, 10, 10, 10, 20}, {10, 10, 10, 10, 20}); + gert::StorageShape gamma({1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}); + gert::StorageShape beta({10, 10, 10, 10, 20}, {10, 10, 10, 10, 20}); + gert::StorageShape result({10, 10, 10, 10, 20}, {10, 10, 10, 10, 20}); std::string tmp_compile_json = "123.json"; @@ -509,16 +412,18 @@ TEST_F(UtestContextBuilder, CreateTilingParseContextOK) { OpTilingParseContextBuilder ctx_builder; - auto holder = ctx_builder.OpName("tmp") - .OpType("DIY") - .IONum(4, 1) - .InputTensorDesc(0, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED) - .InputTensorDesc(1, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED) - .InputTensorDesc(2, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED) - .InputTensorDesc(3, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED) - .OutputTensorDesc(0, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED) - .CompiledJson(tmp_compile_json.c_str()) - .CompiledInfo(tmp_compile_info) + ctx_builder.MutableOpInfo() + .OpName("tmp") + .OpType("DIY") + .IONum(4, 1) + .SetInputTd(0, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, x) + .SetInputTd(1, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, resultIn) + .SetInputTd(2, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, gamma) + .SetInputTd(3, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, beta) + .SetOutputTd(0, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, result); + + auto holder = ctx_builder.CompiledJson(tmp_compile_json.c_str()) + .CompileInfo(tmp_compile_info) .PlatformInfo(tmp_platform_info) .Build(); @@ -533,65 +438,3 @@ TEST_F(UtestContextBuilder, CreateTilingParseContextOK) { EXPECT_EQ(ctx->GetCompiledJson(), tmp_compile_json.c_str()); EXPECT_EQ(ctx->GetCompiledInfo(), tmp_compile_info); } - -TEST_F(UtestContextBuilder, CreateTilingParseContextExpandDimsTypeOK) { - std::string tmp_compile_json = "123.json"; - - uint8_t tmp_platform_info[] = {1, 2, 3, 4, 5, 6, 7}; - uint8_t tmp_compile_info[] = {1, 2, 3, 4, 5, 6, 7}; - - OpTilingParseContextBuilder ctx_builder; - gert::ExpandDimsType expand_dims_type11; - expand_dims_type11.SetExpandIndex(11); - gert::ExpandDimsType expand_dims_type12; - expand_dims_type12.SetExpandIndex(12); - auto holder = ctx_builder.OpName("tmp") - .OpType("DIY") - .IONum(2, 1) - .InputTensorDesc(0, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, expand_dims_type11) - .InputTensorDesc(1, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, expand_dims_type12) - .OutputTensorDesc(0, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, gert::ExpandDimsType()) - .CompiledJson(tmp_compile_json.c_str()) - .CompiledInfo(tmp_compile_info) - .PlatformInfo(tmp_platform_info) - .Build(); - - auto ctx = holder.GetContext(); - EXPECT_NE(ctx, nullptr); - auto ctx_compute_node_info = ctx->GetComputeNodeInfo(); - EXPECT_NE(ctx_compute_node_info, nullptr); - EXPECT_EQ(ctx->GetInputDesc(0)->GetExpandDimsType(), expand_dims_type11); - EXPECT_EQ(ctx->GetInputDesc(1)->GetExpandDimsType(), expand_dims_type12); -} - -TEST_F(UtestContextBuilder, CreateTilingParseContextExpandDimsTypeFailed) { - std::string tmp_compile_json = "123.json"; - - uint8_t tmp_platform_info[] = {1, 2, 3, 4, 5, 6, 7}; - uint8_t tmp_compile_info[] = {1, 2, 3, 4, 5, 6, 7}; - - OpTilingParseContextBuilder ctx_builder; - gert::ExpandDimsType expand_dims_type11; - expand_dims_type11.SetExpandIndex(11); - gert::ExpandDimsType expand_dims_type12; - expand_dims_type12.SetExpandIndex(12); - auto holder = ctx_builder.OpName("tmp") - .OpType("DIY") - .IONum(2, 1) - .InputTensorDesc(0, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, expand_dims_type11) - .InputTensorDesc(1, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, expand_dims_type12) - .InputTensorDesc(2, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, expand_dims_type12) - .OutputTensorDesc(0, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, gert::ExpandDimsType()) - .OutputTensorDesc(1, ge::DT_FLOAT, ge::FORMAT_NCDHW, ge::FORMAT_RESERVED, gert::ExpandDimsType()) - .CompiledJson(tmp_compile_json.c_str()) - .CompiledInfo(tmp_compile_info) - .PlatformInfo(tmp_platform_info) - .Build(); - - auto ctx = holder.GetContext(); - EXPECT_NE(ctx, nullptr); - auto ctx_compute_node_info = ctx->GetComputeNodeInfo(); - EXPECT_NE(ctx_compute_node_info, nullptr); - EXPECT_EQ(ctx->GetInputDesc(0)->GetExpandDimsType(), expand_dims_type11); - EXPECT_EQ(ctx->GetInputDesc(1)->GetExpandDimsType(), expand_dims_type12); -} -- Gitee