diff --git a/exe_graph/lowering/value_holder.cc b/exe_graph/lowering/value_holder.cc index 5257ca6529bd2b7f3ef49ec0ff7dc4b7d18d9804..06d729aa485b73c4010f758745d834911b58f5ec 100644 --- a/exe_graph/lowering/value_holder.cc +++ b/exe_graph/lowering/value_holder.cc @@ -211,7 +211,7 @@ std::vector ValueHolder::CreateDataOutput(const char *node_type, } return CreateFromNode(node, out_count); } -ValueHolderPtr ValueHolder::CreateVoid(const char *node_type, const vector &inputs) { +ValueHolderPtr ValueHolder::CreateVoid(const char *node_type, const std::vector &inputs) { auto node = CreateNode(node_type, inputs, 0); GE_ASSERT_NOTNULL(node); return CreateFromNode(node, -1, kOutput); @@ -237,7 +237,7 @@ ValueHolderPtr ValueHolder::CreateFeed(int64_t index) { return CreateFromNode(node, 0, kFeed); } -ValueHolderPtr ValueHolder::CreateSingleDataOutput(const char *node_type, const vector &inputs) { +ValueHolderPtr ValueHolder::CreateSingleDataOutput(const char *node_type, const std::vector &inputs) { auto holders = CreateDataOutput(node_type, inputs, 1U); if (holders.empty()) { return nullptr; @@ -325,7 +325,7 @@ ge::graphStatus ValueHolder::RefFrom(const ValueHolderPtr &other) { return ge::GRAPH_SUCCESS; } ValueHolderPtr ValueHolder::CreateVoidGuarder(const char *node_type, const ValueHolderPtr &resource, - const vector &args) { + const std::vector &args) { std::vector inputs; inputs.reserve(args.size() + 1); inputs.emplace_back(resource); diff --git a/graph/CMakeLists.txt b/graph/CMakeLists.txt index 6f7fc418eb0301f84348cc62ef6b798466ec1b60..ab31dc18c1ef87d935df0aed1988a401c86ac605 100755 --- a/graph/CMakeLists.txt +++ b/graph/CMakeLists.txt @@ -11,20 +11,10 @@ set(GRAPH_SOURCE_LIST "aligned_ptr.cc" "compute_graph.cc" "ascend_string.cc" - "axis_type_info.cc" - "gnode.cc" - "graph.cc" - "inference_context.cc" - "shape_refiner.cc" - "format_refiner.cc" - "ref_relation.cc" "model.cc" "model_serialize.cc" "node.cc" "op_desc.cc" - "operator.cc" - "operator_factory.cc" - "operator_factory_impl.cc" "ge_attr_define.cc" "ge_tensor.cc" "common/large_bm.cc" @@ -32,34 +22,20 @@ set(GRAPH_SOURCE_LIST "common/hyper_status.cc" "detail/attributes_holder.cc" "utils/anchor_utils.cc" - "utils/tuning_utils.cc" "utils/graph_utils.cc" - "utils/ffts_graph_utils.cc" "utils/dumper/ge_graph_dumper.cc" "utils/trace/trace_manager.cc" "utils/ge_ir_utils.cc" "utils/node_utils.cc" - "utils/op_desc_utils.cc" "utils/type_utils.cc" "utils/tensor_utils.cc" "utils/constant_utils.cc" "utils/connection_matrix.cc" "utils/cycle_detector.cc" "tensor.cc" - "debug/graph_debug.cc" - "opsproto/opsproto_manager.cc" - "${METADEF_DIR}/ops/op_imp.cpp" "option/ge_context.cc" "option/ge_local_context.cc" - "runtime_inference_context.cc" - "${METADEF_DIR}/third_party/transformer/src/axis_util.cc" - "${METADEF_DIR}/third_party/transformer/src/transfer_shape_according_to_format.cc" - "${METADEF_DIR}/third_party/transformer/src/expand_dimension.cc" - "${METADEF_DIR}/third_party/transformer/src/transfer_range_according_to_format.cc" - "${METADEF_DIR}/third_party/transformer/src/transfer_shape_utils.cc" - "utils/transformer_utils.cc" "utils/file_utils.cc" - "resource_context_mgr.cc" "serialization/attr_serializer.cc" "serialization/string_serializer.cc" "serialization/data_type_serializer.cc" @@ -76,7 +52,6 @@ set(GRAPH_SOURCE_LIST "serialization/list_list_float_serializer.cc" "serialization/attr_serializer_registry.cc" "small_vector.cc" - "operator_impl.cc" "compile_cache_policy/compile_cache_policy.cc" "compile_cache_policy/compile_cache_desc.cc" "compile_cache_policy/policy_register.cc" @@ -87,9 +62,41 @@ set(GRAPH_SOURCE_LIST "profiler.cc" ) + +SET(OPERATOR_SRC_LIST + ${METADEF_DIR}/graph/axis_type_info.cc + ${METADEF_DIR}/graph/operator.cc + ${METADEF_DIR}/graph/operator_factory.cc + ${METADEF_DIR}/graph/operator_factory_impl.cc + ${METADEF_DIR}/graph/operator_impl.cc + ${METADEF_DIR}/graph/graph.cc + ${METADEF_DIR}/graph/gnode.cc + ${METADEF_DIR}/graph/format_refiner.cc + ${METADEF_DIR}/graph/inference_context.cc + ${METADEF_DIR}/graph/ref_relation.cc + ${METADEF_DIR}/graph/resource_context_mgr.cc + ${METADEF_DIR}/graph/runtime_inference_context.cc + ${METADEF_DIR}/graph/shape_refiner.cc + ${METADEF_DIR}/graph/debug/graph_debug.cc + ${METADEF_DIR}/graph/opsproto/opsproto_manager.cc + ${METADEF_DIR}/graph/utils/op_desc_utils.cc + #${METADEF_DIR}/graph/utils/oper_utils.cc + ${METADEF_DIR}/graph/utils/tuning_utils.cc + ${METADEF_DIR}/graph/utils/ffts_graph_utils.cc + ${METADEF_DIR}/graph/utils/transformer_utils.cc + ${METADEF_DIR}/third_party/transformer/src/axis_util.cc + ${METADEF_DIR}/third_party/transformer/src/transfer_shape_according_to_format.cc + ${METADEF_DIR}/third_party/transformer/src/expand_dimension.cc + ${METADEF_DIR}/third_party/transformer/src/transfer_range_according_to_format.cc + ${METADEF_DIR}/third_party/transformer/src/transfer_shape_utils.cc + ${METADEF_DIR}/ops/op_imp.cpp +) + + ######### libgraph.so ############# add_library(graph SHARED ${GRAPH_SOURCE_LIST} + ${OPERATOR_SRC_LIST} $ ) @@ -151,7 +158,52 @@ target_compile_options(graph_static PRIVATE $<$,libgraph,graph> - ) +) + + +######### liboperator.so ############# +add_library(operator SHARED + ${OPERATOR_SRC_LIST} +) + +target_compile_options(operator PRIVATE + $<$,$>:-fexceptions> + $<$,$>: -fno-common -Wextra -Wfloat-equal> + -O2 +) + +target_compile_definitions(operator PRIVATE + $<$,$>:FMK_SUPPORT_DUMP> + $<$:ONLY_COMPILE_OPEN_SRC> +) + +target_include_directories(operator PRIVATE + ${METADEF_DIR} + ${CMAKE_BINARY_DIR} + ${CMAKE_CURRENT_LIST_DIR} + ${METADEF_DIR}/third_party/transformer/inc +) + +target_link_options(operator PRIVATE + -Wl,-Bsymbolic +) + +target_link_libraries(operator + PRIVATE + $ + $<$>:$> + $<$>:$> + fwk_mmpa_headers + graphengine_headers + graph + -Wl,--as-needed + c_sec + slog + -Wl,--no-as-needed + PUBLIC + metadef_headers +) + ############################################################## add_custom_command( diff --git a/graph/compute_graph.cc b/graph/compute_graph.cc index 2fcb58f30190248cf46436c2778206dd1939426b..5b820382bd1815fac4a24f2aa9e9c2896e641b1d 100644 --- a/graph/compute_graph.cc +++ b/graph/compute_graph.cc @@ -641,19 +641,19 @@ void ComputeGraphImpl::SetAllSubgraphs(const std::vector ComputeGraphImpl::GetParentGraph() const { +std::shared_ptr ComputeGraphImpl::GetParentGraph() const { return parent_graph_.lock(); } -void ComputeGraphImpl::SetParentGraph(const shared_ptr &parent) { +void ComputeGraphImpl::SetParentGraph(const std::shared_ptr &parent) { parent_graph_ = parent; } -shared_ptr ComputeGraphImpl::GetParentNode() const { +std::shared_ptr ComputeGraphImpl::GetParentNode() const { return parent_node_.lock(); } -void ComputeGraphImpl::SetParentNode(const shared_ptr &parent) { +void ComputeGraphImpl::SetParentNode(const std::shared_ptr &parent) { parent_node_ = parent; } diff --git a/graph/compute_graph_impl.h b/graph/compute_graph_impl.h index d2d7c384293334571162c379ce9ab94f2524333f..dfb740ed33f57c927245716d3be02fd55fbf8b04 100644 --- a/graph/compute_graph_impl.h +++ b/graph/compute_graph_impl.h @@ -82,10 +82,10 @@ class ComputeGraphImpl { std::vector> GetAllSubgraphs() const; void SetAllSubgraphs(const std::vector> &subgraphs); - shared_ptr GetParentGraph() const; - void SetParentGraph(const shared_ptr &parent); - shared_ptr GetParentNode() const; - void SetParentNode(const shared_ptr &parent); + std::shared_ptr GetParentGraph() const; + void SetParentGraph(const std::shared_ptr &parent); + std::shared_ptr GetParentNode() const; + void SetParentNode(const std::shared_ptr &parent); const std::map> &GetGraphOutNodes() const { return out_nodes_map_; } diff --git a/graph/ge_attr_value.cc b/graph/ge_attr_value.cc index 3cd3d6bf0a420f42c5d06e539f144ad670787585..9e1ea65aae0a23d46b0eb41a1c105bb29bf57d8e 100644 --- a/graph/ge_attr_value.cc +++ b/graph/ge_attr_value.cc @@ -30,7 +30,7 @@ #include "debug/ge_util.h" #include "graph/utils/tensor_utils.h" #include "graph/serialization/attr_serializer_registry.h" -#include "graph/utils/op_desc_utils.h" +#include "graph/utils/graph_utils.h" #include "graph/utils/math_util.h" namespace ge { @@ -94,11 +94,11 @@ bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const std::string &name, ui } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc(const ConstOpDescPtr &org_op_desc) { - return OpDescUtils::CloneOpDesc(org_op_desc); + return GraphUtils::CloneOpDesc(org_op_desc); } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(const ConstOpDescPtr &org_op_desc) { - return OpDescUtils::CopyOpDesc(org_op_desc); + return GraphUtils::CopyOpDesc(org_op_desc); } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY diff --git a/graph/graph.cc b/graph/graph.cc index 4d2a54f5da5b39b198ac5479e753ef4bee6282d4..496cd1020b07d9cae6a9f7a99eeccd8333a2e7a2 100644 --- a/graph/graph.cc +++ b/graph/graph.cc @@ -25,15 +25,11 @@ #include "graph/utils/node_adapter.h" #include "graph/utils/node_utils.h" - -namespace { -const uint32_t kSubgraphIndexOfPartitionedCall = 0U; -} // namespace - namespace ge { class GraphImpl { public: friend class GraphUtils; + friend class OperUtils; GraphImpl(const GraphImpl &) = delete; GraphImpl &operator=(const GraphImpl &) = delete; @@ -943,25 +939,6 @@ GraphUtils::CreateGraphPtrFromComputeGraph(const ge::ComputeGraphPtr compute_gra return graph; } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus GraphUtils::GetIndependentCompileGraphs(const ComputeGraphPtr &compute_graph, - std::vector &independent_compile_subgraphs) { - bool is_pipeline_partitioned = false; - (void)AttrUtils::GetBool(*compute_graph, ATTR_NAME_PIPELINE_PARTITIONED, is_pipeline_partitioned); - if (is_pipeline_partitioned) { - for (const auto &node : compute_graph->GetDirectNode()) { - if (node->GetType() == PARTITIONEDCALL) { - auto sub_graph = NodeUtils::GetSubgraph(*node, kSubgraphIndexOfPartitionedCall); - GE_CHECK_NOTNULL(sub_graph); - independent_compile_subgraphs.emplace_back(sub_graph); - } - } - return GRAPH_SUCCESS; - } - independent_compile_subgraphs.emplace_back(compute_graph); - return GRAPH_SUCCESS; -} - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RecoverGraphOperators(const Graph &graph) { GE_CHECK_NOTNULL(graph.impl_); diff --git a/graph/op_desc.cc b/graph/op_desc.cc index 5ed90477cf732990d384f2480da393d3313f5e29..336e5b4aac65c8de5feee28fb3cfc0579309e97d 100644 --- a/graph/op_desc.cc +++ b/graph/op_desc.cc @@ -24,9 +24,11 @@ #include "graph/utils/op_desc_utils.h" #include "graph/utils/transformer_utils.h" #include "graph/utils/node_utils.h" +#include "graph/utils/mem_utils.h" #include "graph/debug/ge_attr_define.h" #include "register/op_tiling/op_tiling_constants.h" #include "common/util/trace_manager/trace_manager.h" + namespace { using std::make_pair; using std::shared_ptr; @@ -792,7 +794,7 @@ graphStatus OpDescImpl::AddRegisterInputName(const std::string &name) { return GRAPH_SUCCESS; } -vector OpDescImpl::GetRegisterInputName() const { +std::vector OpDescImpl::GetRegisterInputName() const { return register_input_name_; } @@ -832,7 +834,7 @@ graphStatus OpDescImpl::AddRegisterOutputName(const std::string &name) { return GRAPH_SUCCESS; } -vector OpDescImpl::GetRegisterOutputName() const { +std::vector OpDescImpl::GetRegisterOutputName() const { return register_output_name_; } @@ -1123,111 +1125,111 @@ int64_t OpDescImpl::GetStreamId() const { return meta_data_.stream_id_; } -void OpDescImpl::SetInputName(const vector &input_name) { +void OpDescImpl::SetInputName(const std::vector &input_name) { meta_data_.input_names_ = input_name; TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), this->GetName(), "input_name", "", "", ""); } -vector OpDescImpl::GetInputName() const { +std::vector OpDescImpl::GetInputName() const { return meta_data_.input_names_; } -void OpDescImpl::SetSrcName(const vector &src_name) { +void OpDescImpl::SetSrcName(const std::vector &src_name) { meta_data_.src_names_ = src_name; TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), this->GetName(), "src_name", "", "", ""); } -vector OpDescImpl::GetSrcName() const { +std::vector OpDescImpl::GetSrcName() const { return meta_data_.src_names_; } -void OpDescImpl::SetSrcIndex(const vector &src_index) { +void OpDescImpl::SetSrcIndex(const std::vector &src_index) { meta_data_.src_indexes_ = src_index; TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), this->GetName(), "src_index", "", "", ""); } -vector OpDescImpl::GetSrcIndex() const { +std::vector OpDescImpl::GetSrcIndex() const { return meta_data_.src_indexes_; } -void OpDescImpl::SetInputOffset(const vector &input) { +void OpDescImpl::SetInputOffset(const std::vector &input) { meta_data_.input_offsets_ = input; TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), this->GetName(), "input_offset", "", "", ""); } -vector OpDescImpl::GetInputOffset() const { +std::vector OpDescImpl::GetInputOffset() const { return meta_data_.input_offsets_; } -void OpDescImpl::SetOutputOffset(const vector &output) { +void OpDescImpl::SetOutputOffset(const std::vector &output) { meta_data_.output_offsets_ = output; TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), this->GetName(), "out_offset", "", "", ""); } -vector OpDescImpl::GetOutputOffset() const { +std::vector OpDescImpl::GetOutputOffset() const { return meta_data_.output_offsets_; } -void OpDescImpl::SetDstName(const vector &dst_name) { +void OpDescImpl::SetDstName(const std::vector &dst_name) { meta_data_.dst_names_ = dst_name; TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), this->GetName(), "dst_name", "", "", ""); } -vector OpDescImpl::GetDstName() const { +std::vector OpDescImpl::GetDstName() const { return meta_data_.dst_names_; } -void OpDescImpl::SetDstIndex(const vector &dst_index) { +void OpDescImpl::SetDstIndex(const std::vector &dst_index) { meta_data_.dst_indexes_ = dst_index; TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), this->GetName(), "dst_index", "", "", ""); } -void OpDescImpl::SetWorkspace(const vector &workspace) { +void OpDescImpl::SetWorkspace(const std::vector &workspace) { meta_data_.workspaces.assign(workspace.cbegin(), workspace.cend()); } -vector OpDescImpl::GetWorkspace() const { - vector res(meta_data_.workspaces.size()); +std::vector OpDescImpl::GetWorkspace() const { + std::vector res(meta_data_.workspaces.size()); for (size_t i = 0UL; i < meta_data_.workspaces.size(); ++i) { res[i] = meta_data_.workspaces[i]; } return res; } -void OpDescImpl::SetWorkspaceBytes(const vector &workspace_bytes) { +void OpDescImpl::SetWorkspaceBytes(const std::vector &workspace_bytes) { meta_data_.workspace_bytes_list_.assign(workspace_bytes.cbegin(), workspace_bytes.cend()); } -vector OpDescImpl::GetWorkspaceBytes() const { - vector res(meta_data_.workspace_bytes_list_.size()); +std::vector OpDescImpl::GetWorkspaceBytes() const { + std::vector res(meta_data_.workspace_bytes_list_.size()); for (size_t i = 0UL; i < meta_data_.workspace_bytes_list_.size(); ++i) { res[i] = meta_data_.workspace_bytes_list_[i]; } return res; } -void OpDescImpl::SetIsInputConst(const vector &is_input_const) { +void OpDescImpl::SetIsInputConst(const std::vector &is_input_const) { meta_data_.is_input_consts_ = is_input_const; TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), this->GetName(), "is_input_const", "", "", ""); } -vector OpDescImpl::GetIsInputConst() const { +std::vector OpDescImpl::GetIsInputConst() const { return meta_data_.is_input_consts_; } @@ -1658,7 +1660,7 @@ graphStatus OpDesc::AddRegisterInputName(const std::string &name) { return impl_->AddRegisterInputName(name); } -vector OpDesc::GetRegisterInputName() const { +std::vector OpDesc::GetRegisterInputName() const { return impl_->GetRegisterInputName(); } @@ -1677,7 +1679,7 @@ graphStatus OpDesc::AddRegisterOutputName(const std::string &name) { return impl_->AddRegisterOutputName(name); } -vector OpDesc::GetRegisterOutputName() const { +std::vector OpDesc::GetRegisterOutputName() const { return impl_->GetRegisterOutputName(); } @@ -1752,6 +1754,18 @@ std::function OpDesc::GetVerifyFunc() const { return impl_->GetVerifyFunc(); } +std::function OpDesc::GetInferFormatFunc() const { + return impl_->GetInferFormatFunc(); +} + +std::function OpDesc::GetInferDataSliceFunc() const { + return impl_->GetInferDataSliceFunc(); +} + +std::function OpDesc::GetInferValueRangeFunc() const { + return impl_->GetInferValueRangeFunc(); +} + void OpDesc::AddInferFunc(const std::function &func) { impl_->AddInferFunc(func); } @@ -2018,4 +2032,138 @@ void OpDesc::AppendIrInput(std::string name, IrInputType input_type) { const std::vector> &OpDesc::GetIrInputs() const { return impl_->GetIrInputs(); } + +/** + * @brief Add input + * @param [in] name + * @return OpDescBuilder + */ +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name) { + inputs_.emplace_back(std::make_pair(name, GeTensorDesc())); + return *this; +} + +/** + * @brief Add input + * @param [in] name + * @param [in] tensor + * @return OpDescBuilder + */ +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +OpDescBuilder &OpDescBuilder::AddInput(const std::string &name, const GeTensorDesc &tensor) { + inputs_.emplace_back(std::make_pair(name, tensor)); + return *this; +} + +/** + * @brief Add dynamic input + * @param [in] name + * @param [in] num + * @return OpDescBuilder + */ +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name, + const uint32_t num) { + for (uint32_t i = 0U; i < num; i++) { + inputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); + } + return *this; +} + +/** + * @brief Add dynamic input + * @param [in] name + * @param [in] num + * @param [in] tensor + * @return OpDescBuilder + */ +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name, const uint32_t num, const GeTensorDesc &tensor) { + for (uint32_t i = 0U; i < num; i++) { + inputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); + } + return *this; +} + +/** + * @brief Add output + * @param [in] name + * @return OpDescBuilder + */ +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name) { + outputs_.emplace_back(std::make_pair(name, GeTensorDesc())); + return *this; +} + +/** + * @brief Add output + * @param [in] name + * @param [in] tensor + * @return OpDescBuilder + */ +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name, const GeTensorDesc &tensor) { + outputs_.emplace_back(std::make_pair(name, tensor)); + return *this; +} + +/** + * @brief Add dynamic output + * @param [in] name + * @param [in] num + * @return OpDescBuilder + */ +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name, + const uint32_t num) { + for (uint32_t i = 0U; i < num; i++) { + outputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); + } + return *this; +} + +/** + * @brief Add dynamic output + * @param [in] name + * @param [in] num + * @param [in] tensor + * @return OpDescBuilder + */ +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name, const uint32_t num, + const GeTensorDesc &tensor) { + for (uint32_t i = 0U; i < num; i++) { + outputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); + } + return *this; +} + +/** + * @brief Build op_desc + * @return OpDescPtr + */ +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() { + const OpDescPtr op_desc = MakeShared(name_, type_); + if (op_desc == nullptr) { + REPORT_CALL_ERROR("E18888", "create opdesc failed, name:%s, type:%s.", name_.c_str(), type_.c_str()); + GELOGE(GRAPH_FAILED, "[Create][OpDesc] failed, name:%s, type:%s.", name_.c_str(), type_.c_str()); + return nullptr; + } + + for (auto &input : inputs_) { + if (op_desc->AddInputDesc(input.first, input.second) != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E18888", "AddInputDesc failed, op:%s.", name_.c_str()); + GELOGE(GRAPH_FAILED, "[Add][InputDesc] failed, op:%s.", name_.c_str()); + return nullptr; + } + } + + for (auto &output : outputs_) { + if (op_desc->AddOutputDesc(output.first, output.second) != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E18888", "AddOutputDesc failed, op:%s", name_.c_str()); + GELOGE(GRAPH_FAILED, "[Add][OutputDesc] failed, op:%s.", name_.c_str()); + return nullptr; + } + } + + return op_desc; +} } // namespace ge diff --git a/graph/op_desc_impl.h b/graph/op_desc_impl.h index ed4994c8591a2d65fb638aad21ff3e8f14763b2d..6a4cf0b72ec1ac2a44bf13955887ee2a58c8529c 100644 --- a/graph/op_desc_impl.h +++ b/graph/op_desc_impl.h @@ -33,16 +33,16 @@ class MetaDataStore { MetaDataStore(std::string name, std::string type) : name_(std::move(name)), type_(std::move(type)) {} int64_t GetId() const {return id_;} int64_t GetStreamId() const {return stream_id_;} - const vector &GetInputNames() const {return input_names_;} - const vector &GetSrcNames() const {return src_names_;} - const vector &GetSrcIndexes() const {return src_indexes_;} - const vector &GetDstNames() const {return dst_names_;} - const vector &GetDstIndexes() const {return dst_indexes_;} - const vector &GetInputOffsets() const {return input_offsets_;} - const vector &GetOutputOffsets() const {return output_offsets_;} - const vector &GetIsInputConsts() const {return is_input_consts_;} - const vector &GetSubgraphNames() const {return subgraph_names_;} - void AddSubGraphName(const string &name) {subgraph_names_.push_back(name);} + const std::vector &GetInputNames() const {return input_names_;} + const std::vector &GetSrcNames() const {return src_names_;} + const std::vector &GetSrcIndexes() const {return src_indexes_;} + const std::vector &GetDstNames() const {return dst_names_;} + const std::vector &GetDstIndexes() const {return dst_indexes_;} + const std::vector &GetInputOffsets() const {return input_offsets_;} + const std::vector &GetOutputOffsets() const {return output_offsets_;} + const std::vector &GetIsInputConsts() const {return is_input_consts_;} + const std::vector &GetSubgraphNames() const {return subgraph_names_;} + void AddSubGraphName(const std::string &name) {subgraph_names_.push_back(name);} void ClearSubgraphNames() { subgraph_names_.clear(); } private: @@ -101,7 +101,7 @@ class OpDescImpl { const GeTensorDesc &GetInputDesc(const std::string &name) const; GeTensorDescPtr MutableInputDesc(const uint32_t index) const; GeTensorDescPtr MutableInputDesc(const std::string &name) const; - OpDesc::Vistor GetAllInputNames(const ConstOpDescPtr &op_desc) const; + OpDesc::Vistor GetAllInputNames(const ConstOpDescPtr &op_desc) const; void SetOpKernelLibName(const std::string &name); std::string GetOpKernelLibName() const; @@ -154,6 +154,10 @@ class OpDescImpl { std::function GetInferFunc() const; std::function GetVerifyFunc() const; + std::function GetInferFormatFunc() const { return infer_format_func_; } + std::function GetInferValueRangeFunc() const { return infer_value_range_func_; } + std::function GetInferDataSliceFunc() const { return infer_data_slice_func_; } + void AddInferFunc(const std::function &func); void AddInferFormatFunc(const std::function &func); void AddVerifierFunc(const std::function &func); @@ -243,6 +247,7 @@ class OpDescImpl { friend class ModelSerializeImp; friend class OnnxUtils; friend class GraphUtils; + friend class NodeUtils; std::vector subgraph_instance_names_; // subgraph names to index, for a `if` operator: diff --git a/graph/operator.cc b/graph/operator.cc index a2219359735121a5a62e7a1ce66d15346d3855d3..563dc6820b4100c9de987e936e5b67372f3fcb1b 100644 --- a/graph/operator.cc +++ b/graph/operator.cc @@ -375,7 +375,7 @@ OpDescUtils::CopyOperators(const ComputeGraphPtr &dst_compute_graph, dst_op_desc->CopyAttrsFrom(*scr_op_impl_ptr->op_desc_); dst_op_desc->SetName(scr_op_impl_ptr->op_desc_->GetName()); } - dst_op = CreateOperatorFromOpDesc(dst_op_desc); + dst_op = OpDescUtils::CreateOperatorFromOpDesc(dst_op_desc); } else { const auto original_op_desc = scr_op_impl_ptr->node_->GetOpDesc(); if (scr_op_impl_ptr->op_desc_ != original_op_desc) { @@ -402,7 +402,7 @@ OpDescUtils::CopyOperators(const ComputeGraphPtr &dst_compute_graph, GE_CHECK_NOTNULL(dst_node); // to do link egdes } - dst_op = CreateOperatorFromNode(dst_node); + dst_op = OpDescUtils::CreateOperatorFromNode(dst_node); (void)(all_node_info.emplace(dst_op.GetOperatorImplPtr(), dst_node)); } dst_op.operator_impl_->subgraph_names_to_builders_ = src_op.operator_impl_->subgraph_names_to_builders_; diff --git a/graph/opsproto/opsproto_manager.cc b/graph/opsproto/opsproto_manager.cc index 823e2e954a89af7b253bffc5e061ee8d13a4e2e2..0e0ec76fd7bc16f0b79f8362f497814cb1b0ef55 100644 --- a/graph/opsproto/opsproto_manager.cc +++ b/graph/opsproto/opsproto_manager.cc @@ -20,7 +20,7 @@ #include #include #include -#include "debug/ge_util.h" +#include "graph/debug/ge_util.h" #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_log.h" #include "graph/types.h" diff --git a/graph/shape_refiner.cc b/graph/shape_refiner.cc index c39da59d995b6ed046f18de99ef71fcfa3283025..028bae2f6447ee5f5d26cd5f8bc2a8e073678a77 100644 --- a/graph/shape_refiner.cc +++ b/graph/shape_refiner.cc @@ -26,7 +26,6 @@ #include "debug/ge_log.h" #include "debug/ge_op_types.h" -#include "debug/ge_util.h" #include "external/graph/operator_factory.h" #include "graph/operator_factory_impl.h" #include "graph/utils/node_utils.h" diff --git a/graph/utils/graph_utils.cc b/graph/utils/graph_utils.cc index c60201704517b9f7c8bf2a96d264da2bf6a4cbc5..2047446fa17105a1cc888b13d3c6258391841ca2 100644 --- a/graph/utils/graph_utils.cc +++ b/graph/utils/graph_utils.cc @@ -39,8 +39,8 @@ #include "graph/debug/ge_op_types.h" #include "external/ge/ge_api_types.h" #include "graph/debug/ge_attr_define.h" -#include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" +#include "graph/detail/model_serialize_imp.h" #include "graph/compute_graph_impl.h" #include "graph/op_desc_impl.h" #include "mmpa/mmpa_api.h" @@ -71,8 +71,28 @@ const char_t *const kDumpStrAicpu = "Aicpu"; const size_t kNameMax = 255U; const int32_t kCopyGraphMaxRecursionDepth = 10; const int32_t kNameWidth = 5; +const uint32_t kSubgraphIndexOfPartitionedCall = 0U; const std::set kMergeInputSkipTypes{ STREAMACTIVE, STREAMSWITCH, CONSTANT, CONSTANTOP }; -}; +} // namespace + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +graphStatus GraphUtils::GetIndependentCompileGraphs(const ComputeGraphPtr &compute_graph, + std::vector &independent_compile_subgraphs) { + bool is_pipeline_partitioned = false; + (void)AttrUtils::GetBool(*compute_graph, ATTR_NAME_PIPELINE_PARTITIONED, is_pipeline_partitioned); + if (is_pipeline_partitioned) { + for (const auto &node : compute_graph->GetDirectNode()) { + if (node->GetType() == PARTITIONEDCALL) { + auto sub_graph = NodeUtils::GetSubgraph(*node, kSubgraphIndexOfPartitionedCall); + GE_CHECK_NOTNULL(sub_graph); + independent_compile_subgraphs.emplace_back(sub_graph); + } + } + return GRAPH_SUCCESS; + } + independent_compile_subgraphs.emplace_back(compute_graph); + return GRAPH_SUCCESS; +} GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst) { @@ -1685,6 +1705,72 @@ graphStatus GraphUtils::CopyComputeGraph(const ComputeGraphPtr &src_compute_grap return GRAPH_SUCCESS; } +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr GraphUtils::CloneOpDesc(const ConstOpDescPtr &org_op_desc) { + GE_CHECK_NOTNULL_EXEC(org_op_desc, return nullptr); + const auto op_def = ComGraphMakeShared(); + GE_CHECK_NOTNULL_EXEC(op_def, return nullptr); + + ModelSerializeImp imp; + (void)imp.SerializeOpDesc(org_op_desc, op_def.get()); + + imp.SetProtobufOwner(op_def); + OpDescPtr op_desc = nullptr; + GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), + REPORT_CALL_ERROR("E18888", "UnserializeOpDesc failed"); + return op_desc, "[Call][UnserializeOpDesc] op_desc unserialize failed"); + + GE_CHECK_NOTNULL_EXEC(op_desc->impl_, return nullptr); + op_desc->ext_attrs_ = org_op_desc->ext_attrs_; + + // This function may be called by some passes of fusion engine, in this condition, do not need these attribute + if (!op_desc->impl_->input_name_idx_.empty()) { + op_desc->impl_->input_name_idx_.clear(); + } + if (!op_desc->impl_->output_name_idx_.empty()) { + op_desc->impl_->output_name_idx_.clear(); + } + if (!op_desc->impl_->optional_input_names_.empty()) { + op_desc->impl_->optional_input_names_.clear(); + } + + return op_desc; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr GraphUtils::CopyOpDesc(const ConstOpDescPtr &org_op_desc) { + if ((org_op_desc == nullptr) || (org_op_desc->impl_ == nullptr)) { + REPORT_INNER_ERROR("E18888", "org_op_desc is null, check invalid"); + GELOGE(GRAPH_FAILED, "[Check][Param] org_op_desc is null"); + return nullptr; + } + const auto op_def = ComGraphMakeShared(); + GE_CHECK_NOTNULL_EXEC(op_def, return nullptr); + + ModelSerializeImp imp; + (void)imp.SerializeOpDesc(org_op_desc, op_def.get()); + + imp.SetProtobufOwner(op_def); + OpDescPtr op_desc = nullptr; + if (!imp.UnserializeOpDesc(op_desc, *op_def)) { + REPORT_CALL_ERROR("E18888", "UnserializeOpDesc failed."); + return nullptr; + } + + GE_CHECK_NOTNULL_EXEC(op_desc->impl_, return nullptr); + op_desc->ext_attrs_ = org_op_desc->ext_attrs_; + op_desc->impl_->input_name_idx_.insert(org_op_desc->impl_->input_name_idx_.cbegin(), + org_op_desc->impl_->input_name_idx_.cend()); + op_desc->impl_->optional_input_names_.insert(org_op_desc->impl_->optional_input_names_.cbegin(), + org_op_desc->impl_->optional_input_names_.cend()); + op_desc->impl_->output_name_idx_.insert(org_op_desc->impl_->output_name_idx_.cbegin(), + org_op_desc->impl_->output_name_idx_.cend()); + + op_desc->impl_->infer_func_ = org_op_desc->impl_->infer_func_; + op_desc->impl_->infer_format_func_ = org_op_desc->impl_->infer_format_func_; + op_desc->impl_->verifier_func_ = org_op_desc->impl_->verifier_func_; + + return op_desc; +} + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyOpAndSubgraph(const ComputeGraphPtr &src_compute_graph, ComputeGraphPtr &dst_compute_graph, @@ -1699,7 +1785,7 @@ graphStatus GraphUtils::CopyOpAndSubgraph(const ComputeGraphPtr &src_compute_gra const auto src_root_compute_graph = FindRootGraph(src_compute_graph); GE_CHECK_NOTNULL(src_root_compute_graph); for (const auto &n : src_compute_graph->GetDirectNode()) { - const OpDescPtr op_desc = OpDescUtils::CopyOpDesc(n->GetOpDesc()); + const OpDescPtr op_desc = GraphUtils::CopyOpDesc(n->GetOpDesc()); if ((op_desc == nullptr) || (op_desc->impl_ == nullptr)) { REPORT_CALL_ERROR("E18888", "CopyOpDesc failed from node:%s", n->GetName().c_str()); GELOGE(GRAPH_FAILED, "[Copy][OpDesc] from node:%s failed", n->GetName().c_str()); @@ -1913,7 +1999,7 @@ ComputeGraphPtr GraphUtils::CloneGraph(const ComputeGraphPtr &graph, const std:: std::unordered_map all_new_nodes; for (const auto &n : graph->GetDirectNode()) { - const OpDescPtr op_desc = OpDescUtils::CopyOpDesc(n->GetOpDesc()); + const OpDescPtr op_desc = GraphUtils::CopyOpDesc(n->GetOpDesc()); GE_CHK_BOOL_EXEC(op_desc != nullptr, REPORT_CALL_ERROR("E18888", "Create node:%s failed.", n->GetOpDesc()->GetName().c_str()); return nullptr, "[Create][Node] %s failed", n->GetOpDesc()->GetName().c_str()); @@ -2851,7 +2937,7 @@ ComputeGraphPtr GraphUtils::BuildSubgraph(const NodePtr &subgraph_node, const Gr // Add node for (const auto &node : graph_info.nodes_) { - (void)graph_builder.AddNode(OpDescUtils::CopyOpDesc(node->GetOpDesc())); + (void)graph_builder.AddNode(GraphUtils::CopyOpDesc(node->GetOpDesc())); } // Set Input diff --git a/graph/utils/node_utils.cc b/graph/utils/node_utils.cc index 51b445091aa1ed5a1ae30763d93eacb5220944dc..0e6bd1af9403b843222483f866b32d0499f8c1af 100644 --- a/graph/utils/node_utils.cc +++ b/graph/utils/node_utils.cc @@ -14,16 +14,17 @@ * limitations under the License. */ #include "graph/utils/node_utils.h" + #include -#include -#include "graph/utils/op_desc_utils.h" + +#include "securec.h" #include "graph/utils/graph_utils.h" #include "graph/debug/ge_op_types.h" #include "graph/debug/ge_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/node_impl.h" +#include "graph/op_desc_impl.h" #include "graph/ge_context.h" -#include "graph/runtime_inference_context.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/tensor_adapter.h" #include "graph/utils/type_utils.h" @@ -297,6 +298,47 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendInpu return GRAPH_SUCCESS; } +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool NodeUtils::ClearInputDesc(const OpDescPtr &op_desc, + const uint32_t index) { + GE_CHK_BOOL_EXEC((op_desc != nullptr) && (op_desc->impl_ != nullptr), + REPORT_INNER_ERROR("E18888", "op_desc is nullptr, check invalid"); + return false, "[Check][Param] op_desc is nullptr"); + GE_CHK_BOOL_EXEC(index < op_desc->impl_->inputs_desc_.size(), + REPORT_INNER_ERROR("E18888", "index %u is invalid, out of range(0, %zu).", + index, op_desc->impl_->inputs_desc_.size()); + return false, + "[Check][Param] index %u is invalid, out of range(0, %zu).", + index, op_desc->impl_->inputs_desc_.size()); + + const auto iter = op_desc->impl_->inputs_desc_.begin() + static_cast(index); + if (iter < op_desc->impl_->inputs_desc_.end()) { + (void)op_desc->impl_->inputs_desc_.erase(iter); + } else { + GELOGW("[Clear][InputDesc] inputs_desc_ iterator out of range."); + } + return true; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool NodeUtils::ClearOutputDesc(const OpDescPtr &op_desc, + const uint32_t index) { + GE_CHK_BOOL_EXEC((op_desc != nullptr) && (op_desc->impl_ != nullptr), + REPORT_INNER_ERROR("E18888", "param op_desc is nullptr, check invalid"); + return false, "[Check][Param] op_desc is nullptr"); + GE_CHK_BOOL_EXEC(index < op_desc->impl_->outputs_desc_.size(), + REPORT_INNER_ERROR("E18888", "index %u is invalid. out of range(0, %zu)", + index, op_desc->impl_->outputs_desc_.size()); + return false, + "[Check][Param] index %u is invalid. out of range(0, %zu)", + index, op_desc->impl_->outputs_desc_.size()); + const auto iter = op_desc->impl_->outputs_desc_.begin() + static_cast(index); + if (iter < op_desc->impl_->outputs_desc_.end()) { + (void)op_desc->impl_->outputs_desc_.erase(iter); + } else { + GELOGW("[Clear][OutputDesc] outputs_desc_ iterator out of range."); + } + return true; +} + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveInputAnchor(const NodePtr &node, const uint32_t num) { if ((node == nullptr) || (node->impl_ == nullptr)) { @@ -307,7 +349,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveInpu const auto &op_desc = node->GetOpDesc(); while (op_desc->GetInputsSize() > num) { - if (!OpDescUtils::ClearInputDesc(op_desc, num)) { + if (!NodeUtils::ClearInputDesc(op_desc, num)) { return GRAPH_FAILED; } } @@ -367,7 +409,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveOutp const auto &op_desc = node->GetOpDesc(); const auto output_names = op_desc->GetAllOutputName(); while (op_desc->GetOutputsSize() > num) { - if (!OpDescUtils::ClearOutputDesc(op_desc, num)) { + if (!NodeUtils::ClearOutputDesc(op_desc, num)) { return GRAPH_FAILED; } } diff --git a/graph/utils/op_desc_utils.cc b/graph/utils/op_desc_utils.cc index 8171b767301d1be328df9216e273086ad532150d..e371846d0d7cc004e0b68b33212cb13b7a6efca7 100644 --- a/graph/utils/op_desc_utils.cc +++ b/graph/utils/op_desc_utils.cc @@ -27,8 +27,6 @@ #include "graph/utils/node_utils.h" #include "graph/utils/constant_utils.h" #include "graph/operator_impl.h" -#include "proto/ge_ir.pb.h" -#include "graph/detail/model_serialize_imp.h" /*lint -e512 -e737 -e752*/ namespace ge { @@ -69,23 +67,7 @@ bool OpDescUtils::ClearInputDesc(const NodePtr &node) { GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearInputDesc(const OpDescPtr op_desc, const uint32_t index) { - GE_CHK_BOOL_EXEC((op_desc != nullptr) && (op_desc->impl_ != nullptr), - REPORT_INNER_ERROR("E18888", "op_desc is nullptr, check invalid"); - return false, "[Check][Param] op_desc is nullptr"); - GE_CHK_BOOL_EXEC(index < op_desc->impl_->inputs_desc_.size(), - REPORT_INNER_ERROR("E18888", "index %u is invalid, out of range(0, %zu).", - index, op_desc->impl_->inputs_desc_.size()); - return false, - "[Check][Param] index %u is invalid, out of range(0, %zu).", - index, op_desc->impl_->inputs_desc_.size()); - - const auto iter = op_desc->impl_->inputs_desc_.begin() + static_cast(index); - if (iter < op_desc->impl_->inputs_desc_.end()) { - (void)op_desc->impl_->inputs_desc_.erase(iter); - } else { - GELOGW("[Clear][InputDesc] inputs_desc_ iterator out of range."); - } - return true; + return NodeUtils::ClearInputDesc(op_desc, index); } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::HasQuantizeFactorParams(const OpDescPtr &op_desc) { @@ -127,22 +109,7 @@ bool OpDescUtils::ClearOutputDesc(const NodePtr &node) { GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearOutputDesc(const OpDescPtr &op_desc, const uint32_t index) { - GE_CHK_BOOL_EXEC((op_desc != nullptr) && (op_desc->impl_ != nullptr), - REPORT_INNER_ERROR("E18888", "param op_desc is nullptr, check invalid"); - return false, "[Check][Param] op_desc is nullptr"); - GE_CHK_BOOL_EXEC(index < op_desc->impl_->outputs_desc_.size(), - REPORT_INNER_ERROR("E18888", "index %u is invalid. out of range(0, %zu)", - index, op_desc->impl_->outputs_desc_.size()); - return false, - "[Check][Param] index %u is invalid. out of range(0, %zu)", - index, op_desc->impl_->outputs_desc_.size()); - const auto iter = op_desc->impl_->outputs_desc_.begin() + static_cast(index); - if (iter < op_desc->impl_->outputs_desc_.end()) { - (void)op_desc->impl_->outputs_desc_.erase(iter); - } else { - GELOGW("[Clear][OutputDesc] outputs_desc_ iterator out of range."); - } - return true; + return NodeUtils::ClearOutputDesc(op_desc, index); } bool OpDescUtils::HasQuantizeFactorParams(const OpDesc &op_desc) { return op_desc.HasAttr(OP_DESC_QUANT_PARAMS); } @@ -298,7 +265,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector OpD return ret; } -vector OpDescUtils::GetWeightsFromNodes( +std::vector OpDescUtils::GetWeightsFromNodes( const std::vector &input_nodes_2_out_anchors) { std::vector ret; for (const auto &input_node_2_anchor : input_nodes_2_out_anchors) { @@ -746,69 +713,11 @@ OpDescUtils::SetWeights(ge::Node &node, const std::map &we } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescUtils::CloneOpDesc(const ConstOpDescPtr &org_op_desc) { - GE_CHECK_NOTNULL_EXEC(org_op_desc, return nullptr); - const auto op_def = ComGraphMakeShared(); - GE_CHECK_NOTNULL_EXEC(op_def, return nullptr); - - ModelSerializeImp imp; - (void)imp.SerializeOpDesc(org_op_desc, op_def.get()); - - imp.SetProtobufOwner(op_def); - OpDescPtr op_desc = nullptr; - GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), - REPORT_CALL_ERROR("E18888", "UnserializeOpDesc failed"); - return op_desc, "[Call][UnserializeOpDesc] op_desc unserialize failed"); - - GE_CHECK_NOTNULL_EXEC(op_desc->impl_, return nullptr); - op_desc->ext_attrs_ = org_op_desc->ext_attrs_; - - // This function may be called by some passes of fusion engine, in this condition, do not need these attribute - if (!op_desc->impl_->input_name_idx_.empty()) { - op_desc->impl_->input_name_idx_.clear(); - } - if (!op_desc->impl_->output_name_idx_.empty()) { - op_desc->impl_->output_name_idx_.clear(); - } - if (!op_desc->impl_->optional_input_names_.empty()) { - op_desc->impl_->optional_input_names_.clear(); - } - - return op_desc; + return GraphUtils::CloneOpDesc(org_op_desc); } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescUtils::CopyOpDesc(const ConstOpDescPtr &org_op_desc) { - if ((org_op_desc == nullptr) || (org_op_desc->impl_ == nullptr)) { - REPORT_INNER_ERROR("E18888", "org_op_desc is null, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] org_op_desc is null"); - return nullptr; - } - const auto op_def = ComGraphMakeShared(); - GE_CHECK_NOTNULL_EXEC(op_def, return nullptr); - - ModelSerializeImp imp; - (void)imp.SerializeOpDesc(org_op_desc, op_def.get()); - - imp.SetProtobufOwner(op_def); - OpDescPtr op_desc = nullptr; - if (!imp.UnserializeOpDesc(op_desc, *op_def)) { - REPORT_CALL_ERROR("E18888", "UnserializeOpDesc failed."); - return nullptr; - } - - GE_CHECK_NOTNULL_EXEC(op_desc->impl_, return nullptr); - op_desc->ext_attrs_ = org_op_desc->ext_attrs_; - op_desc->impl_->input_name_idx_.insert(org_op_desc->impl_->input_name_idx_.cbegin(), - org_op_desc->impl_->input_name_idx_.cend()); - op_desc->impl_->optional_input_names_.insert(org_op_desc->impl_->optional_input_names_.cbegin(), - org_op_desc->impl_->optional_input_names_.cend()); - op_desc->impl_->output_name_idx_.insert(org_op_desc->impl_->output_name_idx_.cbegin(), - org_op_desc->impl_->output_name_idx_.cend()); - - op_desc->impl_->infer_func_ = org_op_desc->impl_->infer_func_; - op_desc->impl_->infer_format_func_ = org_op_desc->impl_->infer_format_func_; - op_desc->impl_->verifier_func_ = org_op_desc->impl_->verifier_func_; - - return op_desc; + return GraphUtils::CopyOpDesc(org_op_desc); } OpDescPtr OpDescUtils::CreateConstOp(const GeTensorPtr &tensor_ptr) { @@ -888,140 +797,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWei return GRAPH_SUCCESS; } -/// -/// @brief Add input -/// @param [in] name -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder& OpDescBuilder::AddInput(const std::string &name) { - inputs_.emplace_back(std::make_pair(name, GeTensorDesc())); - return *this; -} - -/// -/// @brief Add input -/// @param [in] name -/// @param [in] tensor -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -OpDescBuilder& OpDescBuilder::AddInput(const std::string &name, const GeTensorDesc &tensor) { - inputs_.emplace_back(std::make_pair(name, tensor)); - return *this; -} - -/// -/// @brief Add dynamic input -/// @param [in] name -/// @param [in] num -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder& OpDescBuilder::AddDynamicInput(const std::string &name, - const uint32_t num) { - for (uint32_t i = 0U; i < num; i++) { - inputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); - } - return *this; -} - -/// -/// @brief Add dynamic input -/// @param [in] name -/// @param [in] num -/// @param [in] tensor -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -OpDescBuilder& OpDescBuilder::AddDynamicInput(const std::string &name, const uint32_t num, const GeTensorDesc &tensor) { - for (uint32_t i = 0U; i < num; i++) { - inputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); - } - return *this; -} - -/// -/// @brief Add output -/// @param [in] name -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder& OpDescBuilder::AddOutput(const std::string &name) { - outputs_.emplace_back(std::make_pair(name, GeTensorDesc())); - return *this; -} - -/// -/// @brief Add output -/// @param [in] name -/// @param [in] tensor -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -OpDescBuilder& OpDescBuilder::AddOutput(const std::string &name, const GeTensorDesc &tensor) { - outputs_.emplace_back(std::make_pair(name, tensor)); - return *this; -} - -/// -/// @brief Add dynamic output -/// @param [in] name -/// @param [in] num -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder& OpDescBuilder::AddDynamicOutput(const std::string &name, - const uint32_t num) { - for (uint32_t i = 0U; i < num; i++) { - outputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); - } - return *this; -} - -/// -/// @brief Add dynamic output -/// @param [in] name -/// @param [in] num -/// @param [in] tensor -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -OpDescBuilder& OpDescBuilder::AddDynamicOutput(const std::string &name, const uint32_t num, - const GeTensorDesc &tensor) { - for (uint32_t i = 0U; i < num; i++) { - outputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); - } - return *this; -} - -/// -/// @brief Build op_desc -/// @return OpDescPtr -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() { - const OpDescPtr op_desc = MakeShared(name_, type_); - if (op_desc == nullptr) { - REPORT_CALL_ERROR("E18888", "create opdesc failed, name:%s, type:%s.", name_.c_str(), type_.c_str()); - GELOGE(GRAPH_FAILED, "[Create][OpDesc] failed, name:%s, type:%s.", name_.c_str(), type_.c_str()); - return nullptr; - } - - for (auto &input : inputs_) { - if (op_desc->AddInputDesc(input.first, input.second) != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E18888", "AddInputDesc failed, op:%s.", name_.c_str()); - GELOGE(GRAPH_FAILED, "[Add][InputDesc] failed, op:%s.", name_.c_str()); - return nullptr; - } - } - - for (auto &output : outputs_) { - if (op_desc->AddOutputDesc(output.first, output.second) != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E18888", "AddOutputDesc failed, op:%s", name_.c_str()); - GELOGE(GRAPH_FAILED, "[Add][OutputDesc] failed, op:%s.", name_.c_str()); - return nullptr; - } - } - - return op_desc; -} - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::SetSubgraphInstanceName(const std::string &subgraph_name, const std::string &subgraph_instance_name, diff --git a/graph/utils/oper_utils.cc b/graph/utils/oper_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..a8a91141339fe20b3055ef487b776a97302e21cb --- /dev/null +++ b/graph/utils/oper_utils.cc @@ -0,0 +1,366 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/utils/oper_utils.h" + +#include "framework/common/util.h" +#include "common/util/trace_manager/trace_manager.h" +#include "graph/format_refiner.h" +#include "graph/shape_refiner.h" +#include "graph/operator_impl.h" +#include "graph/operator_factory_impl.h" +#include "graph/common_error_codes.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/debug/ge_op_types.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/transformer_utils.h" +#include "graph/utils/mem_utils.h" + +namespace ge { +graphStatus OperUtils::CallInferFunc(const OpDescPtr &op_desc, Operator &op) { + GE_CHECK_NOTNULL(op_desc, ", Op is null for Infer Shape."); + auto infer_func = op_desc->GetInferFunc(); + if (infer_func == nullptr) { + infer_func = OperatorFactoryImpl::GetInferShapeFunc(op_desc->GetType()); + if (infer_func == nullptr) { + GELOGW("[InferShape][Check] %s does not have infer_func.", op_desc->GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + } + + NodeShapeTransUtils transformer(op_desc); + const auto is_init_success = transformer.Init(); + if (!is_init_success) { + GELOGE(GRAPH_FAILED, "[Call][Init] for transformer failed"); + return GRAPH_FAILED; + } + if (!transformer.CatchFormatAndShape()) { + GELOGE(GRAPH_FAILED, "[Call][CatchFormatAndShape] for transformer failed!"); + return GRAPH_FAILED; + } + graphStatus graph_status = GRAPH_SUCCESS; + { + auto node_ptr = NodeUtils::GetNodeFromOperator(op); + TraceOwnerGuard guard("OP", op_desc->GetName() + ":infershape", + (node_ptr == nullptr) ? "" + : (node_ptr->GetOwnerComputeGraph() == nullptr) + ? std::string("") + : node_ptr->GetOwnerComputeGraph()->GetName()); + graph_status = infer_func(op); + } + if ((graph_status != GRAPH_SUCCESS) && (graph_status != GRAPH_NODE_NEED_REPASS)) { + GELOGE(GRAPH_FAILED, "[Call][InferFunc] for %s failed. ret:%u", op_desc->GetName().c_str(), graph_status); + return GRAPH_FAILED; + } + if (!transformer.UpdateFormatAndShape()) { + GELOGE(GRAPH_FAILED, "[Call][UpdateFormatAndShape] for transformer failed!"); + return GRAPH_FAILED; + } + return graph_status; +} + +graphStatus OperUtils::CallInferFormatFunc(const OpDescPtr &op_desc, Operator &op) { + GE_CHECK_NOTNULL(op_desc, ", Op is null for Infer Format."); + const auto infer_func = op_desc->GetInferFormatFunc(); + if (infer_func != nullptr) { + return static_cast(infer_func(op)); + } + + const InferFormatFunc infer_format_func = OperatorFactoryImpl::GetInferFormatFunc(op_desc->GetType()); + if (infer_format_func == nullptr) { + return op_desc->DefaultInferFormat(); + } + + return infer_format_func(op); +} + +graphStatus OperUtils::CallInferValueRangeFunc(const OpDescPtr &op_desc, Operator &op) { + GE_CHECK_NOTNULL(op_desc, ", Op is null for Infer ValueRange."); + const auto infer_func = op_desc->GetInferValueRangeFunc(); + if (infer_func != nullptr) { + return static_cast(infer_func(op)); + } + + const InferValueRangePara infer_value_range_param = OperatorFactoryImpl::GetInferValueRangePara(op_desc->GetType()); + if (!infer_value_range_param.is_initialized) { + REPORT_CALL_ERROR("E18888", "Node %s does not register func to infer value range.", op_desc->GetName().c_str()); + GELOGE(GRAPH_PARAM_INVALID, "Node %s does not register func to infer value range.", op_desc->GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + + if (infer_value_range_param.infer_value_func == nullptr) { + REPORT_CALL_ERROR("E18888", "Value range infer func of node %s has been registered, but infer func is nullptr.", + op_desc->GetName().c_str()); + GELOGE(GRAPH_PARAM_INVALID, "Value range infer func of node %s has been registered, but infer func is nullptr.", + op_desc->GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + + return infer_value_range_param.infer_value_func(op); +} + +graphStatus OperUtils::OpVerify(const OpDescPtr &op_desc) { + GE_CHECK_NOTNULL(op_desc, ", Op is null for Infer Verify."); + const VerifyFunc verify_func = OperatorFactoryImpl::GetVerifyFunc(op_desc->GetType()); + if (verify_func != nullptr) { + Operator op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); + const graphStatus ret = verify_func(op); + op.BreakConnect(); + return ret; + } + return GRAPH_SUCCESS; +} + +graphStatus OperUtils::InferShapeAndType(const OpDescPtr &op_desc) { + GE_CHECK_NOTNULL(op_desc, ", Op is null for Infer Shape."); + auto infer_func = op_desc->GetInferFunc(); + if (infer_func == nullptr) { + infer_func = OperatorFactoryImpl::GetInferShapeFunc(op_desc->GetType()); + if (infer_func == nullptr) { + GELOGW("[InferShape][Check] %s does not have infer_func.", op_desc->GetName().c_str()); + /// The infer_func has not been added for each operator in the current operator information library. + /// No infer_func added operator skips the call + /// and directly uses the shape information passed down by the upper framework + return GRAPH_SUCCESS; + } + } + + Operator op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); + const graphStatus ret = infer_func(op); + op.BreakConnect(); + return ret; +} + +graphStatus OperUtils::InferDataSlice(const OpDescPtr &op_desc) { + GE_CHECK_NOTNULL(op_desc, ", Op is null for Infer Slice."); + auto infer_data_slice_func = op_desc->GetInferDataSliceFunc(); + if (infer_data_slice_func == nullptr) { + infer_data_slice_func = OperatorFactoryImpl::GetInferDataSliceFunc(op_desc->GetType()); + if (infer_data_slice_func == nullptr) { + GELOGW("[InferDataSlice][Check] %s does not have infer data slice func.", op_desc->GetName().c_str()); + return NO_DEPENDENCE_FUNC; + } + } + + Operator op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); + const graphStatus ret = infer_data_slice_func(op); + op.BreakConnect(); + return ret; +} + +graphStatus OperUtils::InferShapeAndType(const NodePtr &node) { + GE_CHECK_NOTNULL(node, ", Node is null for Infer Shape."); + Operator op = OpDescUtils::CreateOperatorFromNode(node); + return ShapeRefiner::InferShapeAndType(node, op); +} + +graphStatus OperUtils::InferOriginFormat(const NodePtr &node) { + GE_CHECK_NOTNULL(node, ", Node is null for Infer Format."); + const auto op_desc = node->GetOpDesc(); + const InferFormatFunc infer_format_func = OperatorFactoryImpl::GetInferFormatFunc(op_desc->GetType()); + if (infer_format_func == nullptr) { + return op_desc->DefaultInferFormat(); + } + + Operator op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); + return infer_format_func(op); +} + +graphStatus OperUtils::IsInputsValid(const NodePtr &node) { + const auto &op_desc = node->GetOpDesc(); + for (const auto &in_anchor : node->GetAllInDataAnchors()) { + if (in_anchor == nullptr) { + GELOGW("[Verify][CheckParam] In data anchor is null"); + continue; + } + const bool valid_anchor = (node->GetType() == DATA) || (node->GetType() == AIPPDATA) || + (node->GetType() == CONSTANT) || (node->GetType() == VARIABLE) || + (node->GetType() == CONSTANTOP) || + (op_desc->MutableInputDesc(static_cast(in_anchor->GetIdx())) == nullptr) || + (in_anchor->GetPeerAnchors().size() > 0UL); + if (!valid_anchor) { + ErrorManager::GetInstance().ATCReportErrMessage("E11019", {"opname", "index"}, + {node->GetName(), std::to_string(in_anchor->GetIdx())}); + GELOGE(GRAPH_FAILED, "[Check][Param] operator %s's input %d is not linked.", + node->GetName().c_str(), in_anchor->GetIdx()); + return GRAPH_FAILED; + } + } + return GRAPH_SUCCESS; +} + +graphStatus OperUtils::Verify(const NodePtr &node) { + GE_CHECK_NOTNULL(node, ", Node is null for Infer Verify."); + const bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); + if (!is_unknown_graph) { + GE_CHK_STATUS_RET_NOLOG(IsInputsValid(node)); + } + + const auto op_desc = node->GetOpDesc(); + const bool need_update_name = (node->GetType() != FRAMEWORKOP) && (!is_unknown_graph); + if (need_update_name) { + const auto node_op = OperatorFactoryImpl::CreateOperator("node_op", node->GetType()); + if (node_op.IsEmpty()) { + GELOGW("[Verify][CheckParam] Get op from OperatorFactory failed, type: %s", node->GetType().c_str()); + } else { + GELOGD("get op from OperatorFactory success. opType: %s", node->GetType().c_str()); + const auto temp_op_desc = OpDescUtils::GetOpDescFromOperator(node_op); + if (temp_op_desc == nullptr) { + REPORT_INNER_ERROR("E18888", "GetOpDescFromOperator failed, as return nullptr, type:%s", + node->GetType().c_str()); + GELOGE(GRAPH_FAILED, "[Get][OpDesc] temp op desc is null, type:%s", node->GetType().c_str()); + return GRAPH_FAILED; + } + if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) { + GELOGW("[Verify][Update] Update input name failed"); + } + if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) { + GELOGW("[Verify][Update] Update output name failed"); + } + } + node_op.BreakConnect(); + } + + if (is_unknown_graph) { + return GRAPH_SUCCESS; + } + + if (op_desc->CommonVerify() == GRAPH_SUCCESS) { + Operator op = OpDescUtils::CreateOperatorFromNode(node); + auto verify_func = op_desc->GetVerifyFunc(); + if (verify_func == nullptr) { + verify_func = OperatorFactoryImpl::GetVerifyFunc(node->GetType()); + } + if (verify_func != nullptr) { + return verify_func(op); + } + return GRAPH_SUCCESS; + } else { + REPORT_CALL_ERROR("E18888", "%s(%s) Verify failed.", node->GetName().c_str(), node->GetType().c_str()); + GELOGE(GRAPH_FAILED, "[Call][CommonVerify] %s(%s) failed.", node->GetName().c_str(), node->GetType().c_str()); + return GRAPH_FAILED; + } +} + +graphStatus OperUtils::Verify(const ComputeGraphPtr &graph) { + GE_CHECK_NOTNULL(graph, ", Graph is null for Infer Shape."); + const bool is_unknown_graph = graph->GetGraphUnknownFlag(); + for (const auto &node_ptr : graph->GetAllNodes()) { + GE_CHECK_NOTNULL(node_ptr); + GE_CHECK_NOTNULL(node_ptr->GetOpDesc()); + if (is_unknown_graph) { + continue; + } + if (node_ptr->GetOpDesc()->CommonVerify() != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E18888", "Verifying %s failed.", node_ptr->GetName().c_str()); + GELOGE(FAILED, "[Call][CommonVerify] Verifying %s failed.", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + } + return GRAPH_SUCCESS; +} + +graphStatus OperUtils::InferOriginFormat(const ComputeGraphPtr &graph) { + return FormatRefiner::InferOrigineFormat(graph); +} + +graphStatus OperUtils::InferShapeInNeed(const ComputeGraphPtr &graph) { + GE_LOGW_IF(graph->TopologicalSorting() != GRAPH_SUCCESS, "Verify failed."); + for (const auto &node_ptr : graph->GetAllNodes()) { + GE_CHECK_NOTNULL(node_ptr); + const auto op_desc = node_ptr->GetOpDesc(); + bool is_need_infer = false; + (void)AttrUtils::GetBool(op_desc, NEED_INFER, is_need_infer); + if (is_need_infer) { + if (OperUtils::Verify(node_ptr) != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E18888", "Verifying %s failed.", node_ptr->GetName().c_str()); + GELOGE(FAILED, "[Call][Verify] Verifying %s failed.", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + const graphStatus status = OperUtils::InferShapeAndType(node_ptr); + if ((node_ptr->GetType() != DATA) && (status == GRAPH_PARAM_INVALID)) { + GELOGI("Op %s does not have the IMPLEMT_INFERFUNC definition, " + "and subsequent operators no longer perform shape inference.", + node_ptr->GetName().c_str()); + break; + } + if (status != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E18888", "Inferring %s failed.", node_ptr->GetName().c_str()); + GELOGE(FAILED, "[Call][InferShapeAndType] Inferring %s failed.", node_ptr->GetName().c_str()); + return GRAPH_FAILED; + } + + for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) { + GE_CHECK_NOTNULL(out_anchor->GetOwnerNode()->GetOpDesc()); + auto output_tensor = out_anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc( + static_cast(out_anchor->GetIdx())); + TensorUtils::SetRealDimCnt(output_tensor, static_cast(output_tensor.GetShape().GetDims().size())); + (void)out_anchor->GetOwnerNode()->GetOpDesc()->UpdateOutputDesc(static_cast(out_anchor->GetIdx()), + output_tensor); + for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { + (void)peer_anchor->GetOwnerNode()->GetOpDesc()->UpdateInputDesc(static_cast(peer_anchor->GetIdx()), + output_tensor); + } + } + } + } + return GRAPH_SUCCESS; +} + +graphStatus OperUtils::CopyGraph(const Graph &src_graph, Graph &dst_graph) { + std::string graph_name; + AscendString ascend_name; + if (dst_graph.GetName(ascend_name) == GRAPH_SUCCESS) { + graph_name = std::string((ascend_name.GetString() != nullptr) ? ascend_name.GetString() : ""); + } + if (graph_name.empty() && (src_graph.GetName(ascend_name) == GRAPH_SUCCESS)) { + graph_name = std::string((ascend_name.GetString() != nullptr) ? ascend_name.GetString() : ""); + } + + ComputeGraphPtr new_compute_graph = MakeShared(graph_name); + GE_CHECK_NOTNULL(new_compute_graph); + const ComputeGraphPtr src_compute_graph = GraphUtils::GetComputeGraph(src_graph); + GE_CHECK_NOTNULL(src_compute_graph); + if (src_compute_graph->GetParentGraph() != nullptr) { + GELOGE(GRAPH_FAILED, "[Check][RootGraph] Only support copy root graph, current graph name:%s, " + "parent graph name:%s.", src_compute_graph->GetName().c_str(), + src_compute_graph->GetParentGraph()->GetName().c_str()); + return GRAPH_FAILED; + } + const int32_t depth = 0; + std::map node_old_2_new; + std::map op_desc_old_2_new; + graphStatus ret = GraphUtils::CopyComputeGraph(src_compute_graph, new_compute_graph, + node_old_2_new, op_desc_old_2_new, depth); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "[Copy][Graph] failed, ret:%d.", ret); + return GRAPH_FAILED; + } + Graph tmp_graph = GraphUtils::CreateGraphFromComputeGraph(new_compute_graph); + ret = GraphUtils::CopyGraphImpl(src_graph, tmp_graph, node_old_2_new, op_desc_old_2_new); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "[Copy][GraphImpl] failed, ret:%d.", ret); + return GRAPH_FAILED; + } + std::swap(dst_graph, tmp_graph); + return GRAPH_SUCCESS; +} +} // namespace ge + diff --git a/inc/external/graph/graph.h b/inc/external/graph/graph.h index 05ed5ca1bf5d007fd77b63647c72ec46ebc9f1eb..b19f1048d0c99bea0db38015fa08d730296cc1fa 100644 --- a/inc/external/graph/graph.h +++ b/inc/external/graph/graph.h @@ -35,6 +35,7 @@ using GraphPtr = std::shared_ptr; /*lint -e148*/ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph { friend class GraphUtils; + friend class OperUtils; public: ATTRIBUTED_DEPRECATED(Graph(const char_t *)) @@ -123,7 +124,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph { static GraphPtr ConstructFromInputs(const std::vector &inputs, const AscendString &name); private: - GraphImplPtr impl_{nullptr}; }; } // namespace ge diff --git a/inc/external/graph/operator.h b/inc/external/graph/operator.h index 9a9b7ec73d810db80b67e3a9c5ca2378799b4a63..436029d72b44c6b4d562ec6a66973a5e7018d082 100644 --- a/inc/external/graph/operator.h +++ b/inc/external/graph/operator.h @@ -72,6 +72,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { friend class NodeUtils; friend class OpDescUtils; friend class GraphUtils; + friend class OperUtils; using OpInt = int64_t; using OpFloat = float32_t; diff --git a/inc/graph/compute_graph.h b/inc/graph/compute_graph.h index ef0e816c49a60b588050445dc18aab2397d72267..e2ba5b44fa33171c8e3fac24ef27a217adfaefe6 100644 --- a/inc/graph/compute_graph.h +++ b/inc/graph/compute_graph.h @@ -163,9 +163,9 @@ class ComputeGraph : public std::enable_shared_from_this, public A void AppendGraphOutNodes(const std::map> out_nodes_map); std::shared_ptr GetParentGraph(); - void SetParentGraph(const shared_ptr &parent); + void SetParentGraph(const std::shared_ptr &parent); std::shared_ptr GetParentNode(); - void SetParentNode(const shared_ptr &parent); + void SetParentNode(const std::shared_ptr &parent); const std::map> &GetGraphOutNodes() const; void SetOrigGraph(const ComputeGraphPtr orig_graph); @@ -245,7 +245,7 @@ class ComputeGraph : public std::enable_shared_from_this, public A graphStatus BFSTopologicalSorting(std::vector &node_vec, std::map &map_in_edge_num, std::deque &stack); graphStatus CollectBreadthOutNode(const NodePtr &node, std::map &map_in_edge_num, - std::map &breadth_node_map); + std::map &breadth_node_map); graphStatus SortNodes(std::vector &stack, std::map &map_in_edge_num); Vistor AllGraphNodes(std::vector &subgraphs) const; diff --git a/inc/graph/detail/attributes_holder.h b/inc/graph/detail/attributes_holder.h index 7bff216195a85498065885d96f641a83a97d6be0..693d04eea9cf4cc58251ad05c464b50d7b0d77ec 100644 --- a/inc/graph/detail/attributes_holder.h +++ b/inc/graph/detail/attributes_holder.h @@ -206,6 +206,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder { friend class ModelSerializeImp; friend class AttrUtils; friend class OpDescUtils; + friend class GraphUtils; std::vector required_attrs_; private: diff --git a/inc/graph/op_desc.h b/inc/graph/op_desc.h index 5d1017466f60a01c45ac036aa0ed38e69628eb72..04ed8e10fba3c5022c94e79481f05213549d4ae2 100644 --- a/inc/graph/op_desc.h +++ b/inc/graph/op_desc.h @@ -52,10 +52,9 @@ class OpDesc : public std::enable_shared_from_this, public AttrHolder { using ConstOpDescPtr = std::shared_ptr; template - using Vistor = RangeVistor>; + using Vistor = RangeVistor>; friend class GraphBuilderImpl; - friend class OperatorImpl; OpDesc(const std::string &name, const std::string &type); @@ -106,7 +105,7 @@ class OpDesc : public std::enable_shared_from_this, public AttrHolder { bool IsOptionalInput(const uint32_t index) const; - Vistor GetAllInputNames() const; + Vistor GetAllInputNames() const; GeTensorDescPtr MutableInputDesc(const uint32_t index) const; @@ -191,6 +190,9 @@ class OpDesc : public std::enable_shared_from_this, public AttrHolder { graphStatus DefaultInferFormat(); std::function GetVerifyFunc() const; + std::function GetInferFormatFunc() const; + std::function GetInferDataSliceFunc() const; + std::function GetInferValueRangeFunc() const; void AddVerifierFunc(const std::function &func); @@ -308,10 +310,97 @@ class OpDesc : public std::enable_shared_from_this, public AttrHolder { friend class GeAttrValueImp; friend class OnnxUtils; friend class GraphUtils; + friend class NodeUtils; }; using OpDescPtr = OpDesc::OpDescPtr; using ConstOpDescPtr = OpDesc::ConstOpDescPtr; using ConstOpDesc = const OpDesc; + +class OpDescBuilder { + public: + OpDescBuilder(std::string name, std::string type) : name_(std::move(name)), type_(std::move(type)) {} + OpDescBuilder(const OpDescBuilder &) = delete; + OpDescBuilder &operator=(const OpDescBuilder &) = delete; + OpDescBuilder(const OpDescBuilder &&) = delete; + OpDescBuilder &operator=(const OpDescBuilder &&) = delete; + ~OpDescBuilder() = default; + + /** + * @brief Add input + * @param [in] name + * @return OpDescBuilder + */ + OpDescBuilder &AddInput(const std::string &name); + + /** + * @brief Add input + * @param [in] name + * @param [in] tensor + * @return OpDescBuilder + */ + OpDescBuilder &AddInput(const std::string &name, const GeTensorDesc &tensor); + + /** + * @brief Add dynamic input + * @param [in] name + * @param [in] num + * @return OpDescBuilder + */ + OpDescBuilder &AddDynamicInput(const std::string &name, const uint32_t num); + + /** + * @brief Add dynamic input + * @param [in] name + * @param [in] num + * @param [in] tensor + * @return OpDescBuilder + */ + OpDescBuilder &AddDynamicInput(const std::string &name, const uint32_t num, const GeTensorDesc &tensor); + + /** + * @brief Add output + * @param [in] name + * @return OpDescBuilder + */ + OpDescBuilder &AddOutput(const std::string &name); + + /** + * @brief Add output + * @param [in] name + * @param [in] tensor + * @return OpDescBuilder + */ + OpDescBuilder &AddOutput(const std::string &name, const GeTensorDesc &tensor); + + /** + * @brief Add dynamic output + * @param [in] name + * @param [in] num + * @return OpDescBuilder + */ + OpDescBuilder &AddDynamicOutput(const std::string &name, const uint32_t num); + + /** + * @brief Add dynamic output + * @param [in] name + * @param [in] num + * @param [in] tensor + * @return OpDescBuilder + */ + OpDescBuilder &AddDynamicOutput(const std::string &name, const uint32_t num, const GeTensorDesc &tensor); + + /** + * @brief Build op_desc + * @return OpDescPtr + */ + OpDescPtr Build(); + + private: + std::string name_; + std::string type_; + std::vector> inputs_; + std::vector> outputs_; +}; } // namespace ge #endif // INC_GRAPH_OP_DESC_H_ diff --git a/inc/graph/operator_factory_impl.h b/inc/graph/operator_factory_impl.h index 823803d0ce850e39421648ea515175a098662014..5e7a69982b2261e0e43f2eac31e1cddcca2a6991 100644 --- a/inc/graph/operator_factory_impl.h +++ b/inc/graph/operator_factory_impl.h @@ -37,6 +37,7 @@ struct InferValueRangePara { } friend class OpDescImpl; friend class InferValueRangePass; + friend class OperUtils; ~InferValueRangePara() = default; private: bool is_initialized = false; diff --git a/inc/graph/runtime_inference_context.h b/inc/graph/runtime_inference_context.h index 3d8772412c4b4686df32eb9d17b41412e0c17328..f2a5ace11e8973784463adbb29fd373b03999aae 100644 --- a/inc/graph/runtime_inference_context.h +++ b/inc/graph/runtime_inference_context.h @@ -23,7 +23,7 @@ #include #include "external/graph/ge_error_codes.h" #include "external/graph/tensor.h" -#include "ge_attr_value.h" +#include "graph/ge_attr_value.h" namespace ge { class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY RuntimeInferenceContext { diff --git a/inc/graph/utils/attr_utils.h b/inc/graph/utils/attr_utils.h index 4ebfb2f63d080b158575251ae0f03fe7eec14d55..f560e0f2ddd0cbcba76640afa0eb489b677e3220 100644 --- a/inc/graph/utils/attr_utils.h +++ b/inc/graph/utils/attr_utils.h @@ -118,8 +118,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { static bool GetDataType(ConstAttrHolderAdapter &&obj, const std::string &name, ge::DataType &value); static OpDescPtr CloneOpDesc(const ConstOpDescPtr &org_op_desc); - static OpDescPtr CopyOpDesc(const ConstOpDescPtr &org_op_desc); + static std::string GetAllAttrsStr(ConstAttrHolderAdapter &&obj); static std::map GetAllAttrs(ConstAttrHolderAdapter &&obj); static std::string GetAttrsStrAfterRid(ConstAttrHolderAdapter &&obj, const std::set &un_compute_attrs); diff --git a/inc/graph/utils/graph_utils.h b/inc/graph/utils/graph_utils.h index 15cca92928dd00b1ca3daa4b175818f6966bf6c2..9c4ef8e30037586a097c5c52ff12ef95357d1508 100644 --- a/inc/graph/utils/graph_utils.h +++ b/inc/graph/utils/graph_utils.h @@ -98,7 +98,7 @@ class GraphUtils { static GraphPtr CreateGraphPtrFromComputeGraph(const ComputeGraphPtr compute_graph); static graphStatus GetIndependentCompileGraphs(const ComputeGraphPtr &compute_graph, - std::vector &independent_compile_subgraphs); + std::vector &independent_compile_subgraphs); static graphStatus RecoverGraphOperators(const Graph &graph); @@ -158,6 +158,9 @@ class GraphUtils { std::unordered_map &all_new_nodes, const int32_t depth); + static OpDescPtr CloneOpDesc(const ConstOpDescPtr &org_op_desc); + static OpDescPtr CopyOpDesc(const ConstOpDescPtr &org_op_desc); + static graphStatus CopyMembers(const ComputeGraphPtr &src_compute_graph, ComputeGraphPtr &dst_compute_graph, const std::unordered_map &all_new_nodes); diff --git a/inc/graph/utils/node_utils.h b/inc/graph/utils/node_utils.h index 1262bb10e61b00162b4b10a302392d5d5c28ac72..34010146a04c9948aafdb31b7471d04ca85bcfbd 100644 --- a/inc/graph/utils/node_utils.h +++ b/inc/graph/utils/node_utils.h @@ -65,6 +65,9 @@ class NodeUtils { static bool IsConst(const Node &node); static void UnlinkAll(const Node &node); + static bool ClearInputDesc(const OpDescPtr &op_desc, const uint32_t index); + static bool ClearOutputDesc(const OpDescPtr &op_desc, const uint32_t index); + static graphStatus AppendInputAnchor(const NodePtr &node, const uint32_t num); static graphStatus RemoveInputAnchor(const NodePtr &node, const uint32_t num); diff --git a/inc/graph/utils/op_desc_utils.h b/inc/graph/utils/op_desc_utils.h index 6eeb2d56a1e7a950ccc4dc7ac384601bf1f34a6b..658eb4c6e382f569d3f83a7fd8fbcc3290ded85f 100644 --- a/inc/graph/utils/op_desc_utils.h +++ b/inc/graph/utils/op_desc_utils.h @@ -19,9 +19,11 @@ #include #include + #include "graph/def_types.h" #include "graph/node.h" #include "graph/runtime_inference_context.h" +#include "external/graph/operator.h" /*lint -e148*/ namespace ge { @@ -101,92 +103,6 @@ class OpDescUtils { static graphStatus SetWeights(ge::OpDesc& op_desc, const GeTensorPtr weight); static graphStatus SetWeights(ge::OpDescPtr op_desc, const GeTensorPtr weight); }; - -class OpDescBuilder { - public: - OpDescBuilder(std::string name, std::string type) : name_(std::move(name)), type_(std::move(type)) {} - OpDescBuilder(const OpDescBuilder &) = delete; - OpDescBuilder &operator=(const OpDescBuilder &) = delete; - OpDescBuilder(const OpDescBuilder &&) = delete; - OpDescBuilder &operator=(const OpDescBuilder &&) = delete; - ~OpDescBuilder() = default; - - /// - /// @brief Add input - /// @param [in] name - /// @return OpDescBuilder - /// - OpDescBuilder& AddInput(const std::string &name); - - /// - /// @brief Add input - /// @param [in] name - /// @param [in] tensor - /// @return OpDescBuilder - /// - OpDescBuilder& AddInput(const std::string &name, const GeTensorDesc &tensor); - - /// - /// @brief Add dynamic input - /// @param [in] name - /// @param [in] num - /// @return OpDescBuilder - /// - OpDescBuilder& AddDynamicInput(const std::string &name, const uint32_t num); - - /// - /// @brief Add dynamic input - /// @param [in] name - /// @param [in] num - /// @param [in] tensor - /// @return OpDescBuilder - /// - OpDescBuilder& AddDynamicInput(const std::string &name, const uint32_t num, const GeTensorDesc &tensor); - - /// - /// @brief Add output - /// @param [in] name - /// @return OpDescBuilder - /// - OpDescBuilder& AddOutput(const std::string &name); - - /// - /// @brief Add output - /// @param [in] name - /// @param [in] tensor - /// @return OpDescBuilder - /// - OpDescBuilder& AddOutput(const std::string &name, const GeTensorDesc &tensor); - - /// - /// @brief Add dynamic output - /// @param [in] name - /// @param [in] num - /// @return OpDescBuilder - /// - OpDescBuilder& AddDynamicOutput(const std::string &name, const uint32_t num); - - /// - /// @brief Add dynamic output - /// @param [in] name - /// @param [in] num - /// @param [in] tensor - /// @return OpDescBuilder - /// - OpDescBuilder& AddDynamicOutput(const std::string &name, const uint32_t num, const GeTensorDesc &tensor); - - /// - /// @brief Build op_desc - /// @return OpDescPtr - /// - OpDescPtr Build(); - - private: - std::string name_; - std::string type_; - std::vector> inputs_; - std::vector> outputs_; -}; } // namespace ge /*lint +e148*/ #endif // INC_GRAPH_UTILS_OP_DESC_UTILS_H_ diff --git a/inc/graph/utils/oper_utils.h b/inc/graph/utils/oper_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..8758db5034df9fc290361c75dd7471a2bb44a414 --- /dev/null +++ b/inc/graph/utils/oper_utils.h @@ -0,0 +1,65 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __INC_METADEF_OPER_UTILS_H +#define __INC_METADEF_OPER_UTILS_H + +#include "graph/node.h" +#include "graph/op_desc.h" +#include "graph/compute_graph.h" +#include "external/graph/graph.h" + +namespace ge { +class OperUtils { + public: + // Detach from OpDesc + static graphStatus CallInferFunc(const OpDescPtr &op_desc, Operator &op); + static graphStatus CallInferFormatFunc(const OpDescPtr &op_desc, Operator &op); + static graphStatus CallInferValueRangeFunc(const OpDescPtr &op_desc, Operator &op); + static graphStatus OpVerify(const OpDescPtr &op_desc); + static graphStatus InferShapeAndType(const OpDescPtr &op_desc); + static graphStatus InferDataSlice(const OpDescPtr &op_desc); + + // Detach from Node + static graphStatus Verify(const NodePtr &node); + static graphStatus InferShapeAndType(const NodePtr &node); + static graphStatus InferOriginFormat(const NodePtr &node); + + // Detach from ComputeGraph + static graphStatus Verify(const ComputeGraphPtr &graph); + static graphStatus InferOriginFormat(const ComputeGraphPtr &graph); + static graphStatus InferShapeInNeed(const ComputeGraphPtr &graph); + + // Detach from NodeUtils + static ConstNodePtr GetNodeFromOperator(const Operator &op); + + // Detach from GraphUtils + static ComputeGraphPtr GetComputeGraph(const Graph &graph); + static ComputeGraphPtr CreateGraphFromOperator(const std::string &name, const std::vector &inputs); + static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph); + static GraphPtr CreateGraphPtrFromComputeGraph(const ComputeGraphPtr compute_graph); + static void BreakConnect(const std::map &all_nodes_infos); + static graphStatus RecoverGraphOperators(const Graph &graph); + static graphStatus CopyGraph(const Graph &src_graph, Graph &dst_graph); + + private: + static graphStatus IsInputsValid(const NodePtr &node); + static graphStatus CopyGraphImpl(const Graph &src_graph, Graph &dst_graph, + const std::map &node_old_2_new, + const std::map &op_desc_old_2_new); +}; +} // namespace ge +#endif // __INC_METADEF_OPER_UTILS_H